Fix combine mlir script
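
Collect the constants from both Vicuna modules in a single set, emit each
unique constant once as an ml_program.global loaded back via
ml_program.global_load_const in each forward function, drop the now-unused
util.global suffix bookkeeping, skip a duplicated c1_i64 constant, and fix
the "[DEBIG]"/"vircuna" typos in the debug prints.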

Vivek Khandelwal
2023-08-24 13:26:27 +00:00
parent 79075a1a07
commit 16160d9a7d

@@ -189,25 +189,28 @@ class VicunaBase(SharkLLMBase):
         return vicuna_model
     def combine_mlir_scripts(
-        self, first_vicuna_mlir, second_vicuna_mlir, output_name, model_name=None
+        self,
+        first_vicuna_mlir,
+        second_vicuna_mlir,
+        output_name,
+        model_name=None,
     ):
         print(f"[DEBUG] combining first and second mlir")
-        print(f"[DEBIG] output_name = {output_name}")
+        print(f"[DEBUG] output_name = {output_name}")
         maps1 = []
         maps2 = []
-        constants_1 = set()
-        constants_2 = set()
+        constants = set()
         f1 = []
         f2 = []
-        print(f"[DEBUG] processing first vircuna mlir")
+        print(f"[DEBUG] processing first vicuna mlir")
         first_vicuna_mlir = first_vicuna_mlir.splitlines()
         while first_vicuna_mlir:
             line = first_vicuna_mlir.pop(0)
             if re.search("#map\d*\s*=", line):
                 maps1.append(line)
             elif re.search("arith.constant", line):
-                constants_1.add(line)
+                constants.add(line)
             elif not re.search("module", line):
                 line = re.sub("forward", "first_vicuna_forward", line)
                 f1.append(line)
@@ -224,7 +227,7 @@ class VicunaBase(SharkLLMBase):
             for func_line in f1
         ]
-        print(f"[DEBUG] processing second vircuna mlir")
+        print(f"[DEBUG] processing second vicuna mlir")
         second_vicuna_mlir = second_vicuna_mlir.splitlines()
         while second_vicuna_mlir:
             line = second_vicuna_mlir.pop(0)
@@ -233,7 +236,7 @@ class VicunaBase(SharkLLMBase):
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants_2.add(line)
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
@@ -256,25 +259,15 @@ class VicunaBase(SharkLLMBase):
module_end = "}"
global_vars = []
global_var_loading1 = dict()
global_var_loading2 = dict()
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
# in both 1 and 2
constants = [(e, "") for e in list(constants_1 & constants_2)]
# only in 1
constants.extend(
[(e, "_1") for e in list(constants_1.difference(constants_2))]
)
# only in 2
constants.extend(
[(e, "_2") for e in list(constants_2.difference(constants_1))]
)
del constants_1, constants_2
gc.collect()
counter = 0
constants = list(constants)
while constants:
constant, vname_suf = constants.pop(0)
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
@@ -285,33 +278,34 @@ class VicunaBase(SharkLLMBase):
             vdtype = vbody.split(":")[-1].strip()
             fixed_vdtype = vdtype
-            noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
+            if "c1_i64" in vname:
+                print(constant)
+                counter += 1
+                if counter == 2:
+                    counter = 0
+                    print("detected duplicate")
+                    continue
+            vnames.append(vname)
             if "true" not in vname:
                 global_vars.append(
-                    f"util.global private @{vname}{vname_suf} {noinline} = {vbody} : {fixed_vdtype}"
+                    f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
                 )
+                global_var_loading1.append(
+                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
+                )
+                global_var_loading2.append(
+                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
+                )
-                if vname_suf != "_2":
-                    global_var_loading1[
-                        f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : {fixed_vdtype}"
-                    ] = ""
-                if vname_suf != "_1":
-                    global_var_loading2[
-                        f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : {fixed_vdtype}"
-                    ] = ""
             else:
                 global_vars.append(
-                    f"util.global private @{vname}{vname_suf} = {vbody} : i1"
+                    f"ml_program.global private @{vname}({vbody}) : i1"
                 )
+                global_var_loading1.append(
+                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
+                )
+                global_var_loading2.append(
+                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
+                )
-                if vname_suf != "_2":
-                    global_var_loading1[
-                        f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : i1"
-                    ] = ""
-                if vname_suf != "_1":
-                    global_var_loading2[
-                        f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : i1"
-                    ] = ""
         del constants
         gc.collect()
         new_f1, new_f2 = [], []
@@ -319,7 +313,7 @@ class VicunaBase(SharkLLMBase):
         for line in f1:
             if "func.func" in line:
                 new_f1.append(line)
-                for global_var in global_var_loading1.keys():
+                for global_var in global_var_loading1:
                     new_f1.append(global_var)
             else:
                 new_f1.append(line)
@@ -328,7 +322,7 @@ class VicunaBase(SharkLLMBase):
         for line in f2:
             if "func.func" in line:
                 new_f2.append(line)
-                for global_var in global_var_loading2.keys():
+                for global_var in global_var_loading2:
                     if (
                         "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
                         in global_var
@@ -1407,7 +1401,9 @@ class UnshardedVicuna(VicunaBase):
     def compile(self, download_vmfb=False):
         # Testing : DO NOT Download Vmfbs if not found. Modify later
         # download vmfbs for A100
-        print(f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}")
+        print(
+            f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
+        )
         if not self.vicuna_vmfb_path.exists() and download_vmfb:
             download_public_file(
                 f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
@@ -1637,13 +1633,15 @@ class UnshardedVicuna(VicunaBase):
                 str(second_module)
             )
             if self.cache_vicunas:
-                with open(f"second_{self.precision}.mlir", 'w') as f:
+                with open(f"second_{self.precision}.mlir", "w") as f:
                     f.write(second_module)
                 print("Finished writing IR after dynamic")
             combined_module = self.combine_mlir_scripts(
-                first_module, second_module, self.vicuna_mlir_path, self.model_name
+                first_module,
+                second_module,
+                self.vicuna_mlir_path,
+                self.model_name,
             )
             del first_module, second_module
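
For context, the pattern the new code converges on can be sketched standalone. This is a minimal, illustrative sketch, not code from the repo (the helper names hoist_constants and emit_globals are hypothetical): deduplicate every arith.constant line from both modules in one set, declare each unique constant once as an ml_program.global, and generate one ml_program.global_load_const line per constant to splice in after each func.func header so the original SSA names stay bound.

import re

def hoist_constants(mlir_text):
    """Collect unique arith.constant lines; keep everything else as body."""
    constants, body = set(), []
    for line in mlir_text.splitlines():
        if re.search("arith.constant", line):
            constants.add(line.strip())
        else:
            body.append(line)
    return constants, body

def emit_globals(constants):
    """Turn '%name = arith.constant ... : type' into one ml_program.global
    declaration plus a global_load_const line that rebinds %name."""
    global_vars, loads = [], []
    for constant in sorted(constants):
        vname, vbody = constant.split("=", 1)
        vname = vname.strip().lstrip("%")
        vbody = vbody.strip()
        vdtype = vbody.split(":")[-1].strip()
        global_vars.append(
            f"ml_program.global private @{vname}({vbody}) : {vdtype}"
        )
        loads.append(
            f"\t\t%{vname} = ml_program.global_load_const @{vname} : {vdtype}"
        )
    return global_vars, loads

This mirrors what the diff does inline: global_var_loading1 and global_var_loading2 now hold identical load lines, and the loops after each func.func header splice them into the first and second forward functions respectively.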