changed method of compiling vicuna to remove first and second vicuna (#1611)

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
Author: Eliasj42
Date: 2023-07-03 12:12:43 -07:00
Committed by: GitHub
Parent commit: d63ce76dd8
Commit: 4015793f84
2 changed files with 317 additions and 265 deletions
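In short: each sharded decoder layer used to be compiled twice, into a prefill vmfb and a token-generation vmfb, each exposing a single "forward" entry point. This commit compiles both variants into one vmfb per layer exposing "first_vicuna_forward" and "second_vicuna_forward". A minimal sketch of the calling convention, using a hypothetical stand-in object in place of a compiled SharkInference module (only the module(function_name, inputs) call shape is taken from the diff below):

# Hypothetical stand-in: mimics only the module(function_name, inputs)
# call shape used throughout this diff; a real SharkInference module
# would run compiled code instead of echoing its inputs.
class FakeModule:
    def __init__(self, entry_points):
        self.entry_points = set(entry_points)

    def __call__(self, function_name, inputs):
        assert function_name in self.entry_points
        return inputs

# Before: two modules per layer, each with one "forward" entry point.
first = FakeModule({"forward"})
second = FakeModule({"forward"})
first("forward", ("hidden_states", "attention_mask", "position_ids"))
second("forward", ("hidden_states", "attention_mask", "position_ids", "pkv0", "pkv1"))

# After: one module per layer with two entry points.
combined = FakeModule({"first_vicuna_forward", "second_vicuna_forward"})
combined("first_vicuna_forward", ("hidden_states", "attention_mask", "position_ids"))
combined("second_vicuna_forward", ("hidden_states", "attention_mask", "position_ids", "pkv0", "pkv1"))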

File: apps/language_models/src/model_wrappers/vicuna_sharded_model.py

@@ -62,104 +62,23 @@ class SecondVicunaLayer(torch.nn.Module):
)
-class CompiledFirstVicunaLayer(torch.nn.Module):
-    def __init__(self, shark_module):
-        super().__init__()
-        self.model = shark_module
-    def forward(
-        self,
-        hidden_states,
-        attention_mask,
-        position_ids,
-        past_key_value=None,
-        output_attentions=False,
-        use_cache=True,
-    ):
-        hidden_states = hidden_states.detach()
-        attention_mask = attention_mask.detach()
-        position_ids = position_ids.detach()
-        output = self.model(
-            "forward",
-            (
-                hidden_states,
-                attention_mask,
-                position_ids,
-            ),
-        )
-        output0 = torch.tensor(output[0])
-        output1 = torch.tensor(output[1])
-        output2 = torch.tensor(output[2])
-        return (
-            output0,
-            (
-                output1,
-                output2,
-            ),
-        )
-class CompiledSecondVicunaLayer(torch.nn.Module):
-    def __init__(self, shark_module):
-        super().__init__()
-        self.model = shark_module
-    def forward(
-        self,
-        hidden_states,
-        attention_mask,
-        position_ids,
-        past_key_value,
-        output_attentions=False,
-        use_cache=True,
-    ):
-        hidden_states = hidden_states.detach()
-        attention_mask = attention_mask.detach()
-        position_ids = position_ids.detach()
-        pkv0 = past_key_value[0].detach()
-        pkv1 = past_key_value[1].detach()
-        output = self.model(
-            "forward",
-            (
-                hidden_states,
-                attention_mask,
-                position_ids,
-                pkv0,
-                pkv1,
-            ),
-        )
-        output0 = torch.tensor(output[0])
-        output1 = torch.tensor(output[1])
-        output2 = torch.tensor(output[2])
-        return (
-            output0,
-            (
-                output1,
-                output2,
-            ),
-        )
class ShardedVicunaModel(torch.nn.Module):
-    def __init__(self, model, layers0, layers1, lmhead, embedding, norm):
+    def __init__(self, model, layers, lmhead, embedding, norm):
        super().__init__()
        self.model = model
-        assert len(layers0) == len(model.model.layers)
-        # self.model.model.layers = torch.nn.modules.container.ModuleList(layers0)
+        assert len(layers) == len(model.model.layers)
        self.model.model.config.use_cache = True
        self.model.model.config.output_attentions = False
-        self.layers0 = layers0
-        self.layers1 = layers1
+        self.layers = layers
        self.norm = norm
        self.embedding = embedding
        self.lmhead = lmhead
        self.model.model.norm = self.norm
        self.model.model.embed_tokens = self.embedding
        self.model.lm_head = self.lmhead
+        self.model.model.layers = torch.nn.modules.container.ModuleList(
+            self.layers
+        )
def forward(
self,
@@ -168,20 +87,11 @@ class ShardedVicunaModel(torch.nn.Module):
past_key_values=None,
attention_mask=None,
):
-        if is_first:
-            self.model.model.layers = torch.nn.modules.container.ModuleList(
-                self.layers0
-            )
-            return self.model.forward(input_ids, attention_mask=attention_mask)
-        else:
-            self.model.model.layers = torch.nn.modules.container.ModuleList(
-                self.layers1
-            )
-            return self.model.forward(
-                input_ids,
-                attention_mask=attention_mask,
-                past_key_values=past_key_values,
-            )
+        return self.model.forward(
+            input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+        )
class LMHead(torch.nn.Module):
@@ -248,3 +158,71 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
output = self.model("forward", (input_ids,))
output = torch.tensor(output)
return output
+class CompiledVicunaLayer(torch.nn.Module):
+    def __init__(self, shark_module):
+        super().__init__()
+        self.model = shark_module
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        position_ids,
+        past_key_value=None,
+        output_attentions=False,
+        use_cache=True,
+    ):
+        if past_key_value is None:
+            hidden_states = hidden_states.detach()
+            attention_mask = attention_mask.detach()
+            position_ids = position_ids.detach()
+            output = self.model(
+                "first_vicuna_forward",
+                (
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                ),
+            )
+            output0 = torch.tensor(output[0])
+            output1 = torch.tensor(output[1])
+            output2 = torch.tensor(output[2])
+            return (
+                output0,
+                (
+                    output1,
+                    output2,
+                ),
+            )
+        else:
+            hidden_states = hidden_states.detach()
+            attention_mask = attention_mask.detach()
+            position_ids = position_ids.detach()
+            pkv0 = past_key_value[0].detach()
+            pkv1 = past_key_value[1].detach()
+            output = self.model(
+                "second_vicuna_forward",
+                (
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    pkv0,
+                    pkv1,
+                ),
+            )
+            output0 = torch.tensor(output[0])
+            output1 = torch.tensor(output[1])
+            output2 = torch.tensor(output[2])
+            return (
+                output0,
+                (
+                    output1,
+                    output2,
+                ),
+            )
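For illustration, a minimal sketch of how CompiledVicunaLayer routes calls, assuming a hypothetical MockShark stand-in in place of a compiled SharkInference module (shapes are arbitrary; the array-like outputs just match the torch.tensor(...) wrapping above):

import torch

class MockShark:
    # Stand-in for a compiled module exposing both entry functions;
    # returns (hidden_states, key, value) as array-like values.
    def __call__(self, function_name, inputs):
        hidden_states = inputs[0]
        kv = torch.zeros(1, 32, hidden_states.shape[1], 128)
        return (hidden_states.numpy(), kv.numpy(), kv.numpy())

layer = CompiledVicunaLayer(MockShark())
hs = torch.zeros(1, 7, 4096)
mask = torch.zeros(1, 1, 7, 7)
pos = torch.arange(7).unsqueeze(0)

# No past_key_value: routes to "first_vicuna_forward" (prefill).
out, pkv = layer(hs, mask, pos)
# With past_key_value: routes to "second_vicuna_forward" (one-token step).
out2, _ = layer(hs[:, -1:], mask[:, :, -1:, :], pos[:, -1:], past_key_value=pkv)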

File: Vicuna pipeline script (path not shown in this view)

@@ -1,8 +1,7 @@
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
-    CompiledFirstVicunaLayer,
-    CompiledSecondVicunaLayer,
+    CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
@@ -95,6 +94,7 @@ class Vicuna(SharkLLMBase):
)
continue
line = re.sub(f"{dynamic_input_size}x", "?x", line)
line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line)
if "?x" in line:
line = re.sub(
"tensor.empty\(\)", "tensor.empty(%dim_42)", line
@@ -112,6 +112,124 @@ class Vicuna(SharkLLMBase):
new_module = "\n".join(new_lines)
return new_module
+    def combine_mlir_scripts(
+        self, first_vicuna_mlir, second_vicuna_mlir, output_name
+    ):
+        maps1 = []
+        maps2 = []
+        constants = set()
+        f1 = []
+        f2 = []
+        for line in first_vicuna_mlir.splitlines():
+            if re.search("#map\d*\s*=", line):
+                maps1.append(line)
+            elif re.search("arith.constant", line):
+                constants.add(line)
+            elif not re.search("module", line):
+                line = re.sub("forward", "first_vicuna_forward", line)
+                f1.append(line)
+        f1 = f1[:-1]
+        for i, map_line in enumerate(maps1):
+            map_var = map_line.split(" ")[0]
+            map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
+            maps1[i] = map_line
+            f1 = [
+                re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
+                for func_line in f1
+            ]
+        for line in second_vicuna_mlir.splitlines():
+            if re.search("#map\d*\s*=", line):
+                maps2.append(line)
+            elif "global_seed" in line:
+                continue
+            elif re.search("arith.constant", line):
+                constants.add(line)
+            elif not re.search("module", line):
+                line = re.sub("forward", "second_vicuna_forward", line)
+                f2.append(line)
+        f2 = f2[:-1]
+        for i, map_line in enumerate(maps2):
+            map_var = map_line.split(" ")[0]
+            map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
+            maps2[i] = map_line
+            f2 = [
+                re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
+                for func_line in f2
+            ]
+        module_start = (
+            'module attributes {torch.debug_module_name = "_lambda"} {'
+        )
+        module_end = "}"
+        global_vars = []
+        vnames = []
+        vdtypes = []
+        global_var_loading1 = []
+        global_var_loading2 = []
+        for constant in list(constants):
+            vname, vbody = constant.split("=")
+            vname = re.sub("%", "", vname)
+            vname = vname.strip()
+            vbody = re.sub("arith.constant", "", vbody)
+            vbody = vbody.strip()
+            vdtype = vbody.split(":")[1].strip()
+            fixed_vdtype = vdtype
+            vdtypes.append(vdtype)
+            vdtype = re.sub("\d{1,}x", "?x", vdtype)
+            vnames.append(vname)
+            global_vars.append(
+                f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}"
+            )
+            global_var_loading1.append(
+                f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
+            )
+            global_var_loading2.append(
+                f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
+            )
+        new_f1, new_f2 = [], []
+        for line in f1:
+            if "func.func" in line:
+                new_f1.append(line)
+                for global_var in global_var_loading1:
+                    new_f1.append(global_var)
+            else:
+                new_f1.append(line)
+        for line in f2:
+            if "func.func" in line:
+                new_f2.append(line)
+                for global_var in global_var_loading1:
+                    new_f2.append(global_var)
+            else:
+                new_f2.append(line)
+        f1 = new_f1
+        f2 = new_f2
+        whole_string = "\n".join(
+            maps1
+            + maps2
+            + [module_start]
+            + global_vars
+            + f1
+            + f2
+            + [module_end]
+        )
+        f_ = open(output_name, "w+")
+        f_.write(whole_string)
+        f_.close()
+        return whole_string
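The per-module renaming above is what lets two independently produced MLIR files share one module: affine-map aliases like #map would otherwise collide, and both functions would be named forward. (Note that global_var_loading2 is built but both rewrite loops read global_var_loading1; the two lists are identical, so behavior is unaffected.) A self-contained illustration of the negative-lookahead rename on toy MLIR lines (not taken from real model output):

import re

lines = [
    "#map = affine_map<(d0) -> (d0)>",
    "#map1 = affine_map<(d0) -> (d0 + 1)>",
    "%0 = linalg.generic {indexing_maps = [#map, #map1]} ...",
]
map_var = lines[0].split(" ")[0]  # "#map"
# (?!\d) stops "#map" from also rewriting "#map1", "#map12", ...
renamed = [re.sub(rf"{map_var}(?!\d)", map_var + "_0", l) for l in lines]
print(renamed[0])  # "#map_0 = affine_map<(d0) -> (d0)>"
print(renamed[2])  # "... [#map_0, #map1] ..."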
def compile_vicuna_layer(
self,
vicuna_layer,
@@ -193,6 +311,7 @@ class Vicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
+            mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -235,6 +354,7 @@ class Vicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
+            mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -276,6 +396,7 @@ class Vicuna(SharkLLMBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
+            mmap=False,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -286,171 +407,127 @@ class Vicuna(SharkLLMBase):
return compiled_module
-    def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
-        # compile all layers for vmfb
-        # this needs to be run separately for first and second vicuna
+    def compile_to_vmfb_one_model(
+        self, inputs0, layers0, inputs1, layers1, device="cpu"
+    ):
mlirs, modules = [], []
-        for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
-            if is_first:
-                mlir_path = Path(f"{idx}_0.mlir")
-                vmfb_path = Path(f"{idx}_0.vmfb")
-            else:
-                mlir_path = Path(f"{idx}_1.mlir")
-                vmfb_path = Path(f"{idx}_1.vmfb")
-            if vmfb_path.exists():
-                continue
+        assert len(layers0) == len(layers1)
+        for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))):
+            mlir_path = Path(f"{idx}_full.mlir")
+            vmfb_path = Path(f"{idx}_full.vmfb")
+            # if vmfb_path.exists():
+            #     continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
mlirs.append(bytecode)
else:
-                hidden_states_placeholder = TensorPlaceholder.like(
-                    inputs[0], dynamic_axes=[1]
+                hidden_states_placeholder0 = TensorPlaceholder.like(
+                    inputs0[0], dynamic_axes=[1]
                )
-                attention_mask_placeholder = TensorPlaceholder.like(
-                    inputs[1], dynamic_axes=[3]
+                attention_mask_placeholder0 = TensorPlaceholder.like(
+                    inputs0[1], dynamic_axes=[3]
                )
-                position_ids_placeholder = TensorPlaceholder.like(
-                    inputs[2], dynamic_axes=[1]
+                position_ids_placeholder0 = TensorPlaceholder.like(
+                    inputs0[2], dynamic_axes=[1]
                )
+                hidden_states_placeholder1 = TensorPlaceholder.like(
+                    inputs1[0], dynamic_axes=[1]
+                )
+                attention_mask_placeholder1 = TensorPlaceholder.like(
+                    inputs1[1], dynamic_axes=[3]
+                )
+                position_ids_placeholder1 = TensorPlaceholder.like(
+                    inputs1[2], dynamic_axes=[1]
+                )
+                pkv0_placeholder = TensorPlaceholder.like(
+                    inputs1[3], dynamic_axes=[2]
+                )
+                pkv1_placeholder = TensorPlaceholder.like(
+                    inputs1[4], dynamic_axes=[2]
+                )
-                if not is_first:
-                    pkv0_placeholder = TensorPlaceholder.like(
-                        inputs[3], dynamic_axes=[2]
-                    )
-                    pkv1_placeholder = TensorPlaceholder.like(
-                        inputs[4], dynamic_axes=[2]
-                    )
print(f"Compiling layer {idx} mlir")
-                if is_first:
-                    ts_g = self.compile_vicuna_layer(
-                        layer, inputs[0], inputs[1], inputs[2]
-                    )
-                    module = torch_mlir.compile(
-                        ts_g,
-                        (
-                            hidden_states_placeholder,
-                            inputs[1],
-                            inputs[2],
-                        ),
-                        torch_mlir.OutputType.LINALG_ON_TENSORS,
-                        use_tracing=False,
-                        verbose=False,
-                    )
-                else:
-                    ts_g = self.compile_vicuna_layer(
-                        layer,
-                        inputs[0],
-                        inputs[1],
-                        inputs[2],
-                        inputs[3],
-                        inputs[4],
-                    )
-                    module = torch_mlir.compile(
-                        ts_g,
-                        (
-                            inputs[0],
-                            attention_mask_placeholder,
-                            inputs[2],
-                            pkv0_placeholder,
-                            pkv1_placeholder,
-                        ),
-                        torch_mlir.OutputType.LINALG_ON_TENSORS,
-                        use_tracing=False,
-                        verbose=False,
-                    )
+                ts_g = self.compile_vicuna_layer(
+                    layer0, inputs0[0], inputs0[1], inputs0[2]
+                )
+                module0 = torch_mlir.compile(
+                    ts_g,
+                    (
+                        hidden_states_placeholder0,
+                        inputs0[1],
+                        inputs0[2],
+                    ),
+                    torch_mlir.OutputType.LINALG_ON_TENSORS,
+                    use_tracing=False,
+                    verbose=False,
+                )
+                module0 = self.write_in_dynamic_inputs0(str(module0), 137)
-                if is_first:
-                    module = self.write_in_dynamic_inputs0(str(module), 137)
-                    bytecode = module.encode("UTF-8")
-                    bytecode_stream = BytesIO(bytecode)
-                    bytecode = bytecode_stream.read()
+                ts_g = self.compile_vicuna_layer(
+                    layer1,
+                    inputs1[0],
+                    inputs1[1],
+                    inputs1[2],
+                    inputs1[3],
+                    inputs1[4],
+                )
+                module1 = torch_mlir.compile(
+                    ts_g,
+                    (
+                        inputs1[0],
+                        attention_mask_placeholder1,
+                        inputs1[2],
+                        pkv0_placeholder,
+                        pkv1_placeholder,
+                    ),
+                    torch_mlir.OutputType.LINALG_ON_TENSORS,
+                    use_tracing=False,
+                    verbose=False,
+                )
+                module1 = self.write_in_dynamic_inputs1(str(module1), 138)
-                else:
-                    module = self.write_in_dynamic_inputs1(str(module), 138)
+                module_combined = self.combine_mlir_scripts(
+                    module0, module1, f"{idx}_full.mlir"
+                )
+                mlirs.append(module_combined)
-                    bytecode = module.encode("UTF-8")
-                    bytecode_stream = BytesIO(bytecode)
-                    bytecode = bytecode_stream.read()
-                f_ = open(mlir_path, "wb")
-                f_.write(bytecode)
-                f_.close()
-                mlirs.append(bytecode)
-        for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
-            if is_first:
-                vmfb_path = Path(f"{idx}_0.vmfb")
-                if vmfb_path.exists():
-                    device_idx = self.get_device_index(
-                        f"first_vicuna.model.model.layers.{idx}[\s.$]"
-                    )
-                    module = SharkInference(
-                        None,
-                        device=device,
-                        device_idx=device_idx,
-                        mlir_dialect="tm_tensor",
-                    )
-                    module.load_module(vmfb_path)
-                else:
-                    print(f"Compiling layer {idx} vmfb")
-                    device_idx = self.get_device_index(
-                        f"first_vicuna.model.model.layers.{idx}[\s.$]"
-                    )
-                    module = SharkInference(
-                        mlirs[idx],
-                        device=device,
-                        device_idx=device_idx,
-                        mlir_dialect="tm_tensor",
-                    )
-                    module.save_module(
-                        module_name=f"{idx}_0",
-                        extra_args=[
-                            "--iree-hal-dump-executable-sources-to=ies",
-                            "--iree-vm-target-truncate-unsupported-floats",
-                            "--iree-codegen-check-ir-before-llvm-conversion=false",
-                            "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
-                        ],
-                    )
-                    module.load_module(vmfb_path)
-                modules.append(module)
+            if vmfb_path.exists():
+                # print(f"Found layer {idx} vmfb")
+                device_idx = self.get_device_index(
+                    f"first_vicuna.model.model.layers.{idx}[\s.$]"
+                )
+                module = SharkInference(
+                    None,
+                    device=device,
+                    device_idx=idx % 4,
+                    mlir_dialect="tm_tensor",
+                    mmap=False,
+                )
+                module.load_module(vmfb_path)
-            else:
-                vmfb_path = Path(f"{idx}_1.vmfb")
-                if vmfb_path.exists():
-                    # print(f"Found layer {idx} vmfb")
-                    device_idx = self.get_device_index(
-                        f"second_vicuna.model.model.layers.{idx}[\s.$]"
-                    )
-                    module = SharkInference(
-                        None,
-                        device=device,
-                        device_idx=device_idx,
-                        mlir_dialect="tm_tensor",
-                    )
-                    module.load_module(vmfb_path)
-                else:
-                    print(f"Compiling layer {idx} vmfb")
-                    device_idx = self.get_device_index(
-                        f"second_vicuna.model.model.layers.{idx}[\s.$]"
-                    )
-                    module = SharkInference(
-                        mlirs[idx],
-                        device=device,
-                        device_idx=device_idx,
-                        mlir_dialect="tm_tensor",
-                    )
-                    module.save_module(
-                        module_name=f"{idx}_1",
-                        extra_args=[
-                            "--iree-hal-dump-executable-sources-to=ies",
-                            "--iree-vm-target-truncate-unsupported-floats",
-                            "--iree-codegen-check-ir-before-llvm-conversion=false",
-                            "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
-                        ],
-                    )
-                    module.load_module(vmfb_path)
-                modules.append(module)
print(f"Compiling layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
mlirs[idx],
device=device,
device_idx=idx % 4,
mlir_dialect="tm_tensor",
mmap=False,
)
module.save_module(
module_name=f"{idx}_full",
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules
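The TensorPlaceholder arguments are what make one compiled artifact serve any sequence length: axes marked dynamic become ? dimensions in the emitted IR, which write_in_dynamic_inputs0/1 then rewrite further. A hedged sketch of the pattern on a toy module, assuming the torch-mlir API used elsewhere in this file:

import torch
import torch_mlir
from torch_mlir import TensorPlaceholder

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

example = torch.zeros(1, 7, 4096)
# Mark dim 1 (sequence length) dynamic so the compiled code
# accepts any prompt length, not just 7.
placeholder = TensorPlaceholder.like(example, dynamic_axes=[1])
module = torch_mlir.compile(
    Toy(),
    (placeholder,),
    torch_mlir.OutputType.LINALG_ON_TENSORS,
    use_tracing=False,
)
print(str(module)[:200])  # tensor<1x?x4096xf32> appears in the IR

A separate change visible above: compiled layers are now placed round-robin across devices (device_idx=idx % 4) instead of via get_device_index.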
def get_sharded_model(self, device="cpu"):
@@ -511,26 +588,23 @@ class Vicuna(SharkLLMBase):
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
-        _, modules0 = self.compile_to_vmfb(
-            placeholder_input0,
-            layers0,
-            is_first=True,
-            device=device,
-        )
-        shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
-        _, modules1 = self.compile_to_vmfb(
-            placeholder_input1, layers1, is_first=False, device=device
+        _, modules = self.compile_to_vmfb_one_model(
+            placeholder_input0,
+            layers0,
+            placeholder_input1,
+            layers1,
+            device=device,
+        )
-        shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
+        shark_layers = [CompiledVicunaLayer(m) for m in modules]
sharded_model = ShardedVicunaModel(
vicuna_model,
-            shark_layers0,
-            shark_layers1,
+            shark_layers,
lmhead,
embeddings,
norm,