Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 13:57:54 -05:00)
fix combine mlir for llama2
@@ -244,7 +244,8 @@ class VicunaBase(SharkLLMBase):
         print(f"[DEBUG] output_name = {output_name}")
         maps1 = []
         maps2 = []
-        constants = set()
+        constants_1 = set()
+        constants_2 = set()
         f1 = []
         f2 = []

@@ -255,7 +256,7 @@ class VicunaBase(SharkLLMBase):
             if re.search("#map\d*\s*=", line):
                 maps1.append(line)
             elif re.search("arith.constant", line):
-                constants.add(line)
+                constants_1.add(line)
             elif not re.search("module", line):
                 line = re.sub("forward", "first_vicuna_forward", line)
                 f1.append(line)
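
This hunk, and the matching one for the second module just below, stop funneling every arith.constant line into one shared constants set and instead keep a per-module set. For context, a minimal standalone sketch of the classification the loop performs, assuming the module text arrives as a plain string (the function name and parameters here are illustrative, not taken from the diff):

    import re

    def classify_module_lines(mlir_text, forward_name):
        """Rough sketch: route each line of one MLIR module into affine-map
        definitions, arith.constant lines, or the (renamed) function body."""
        maps, constants, body = [], set(), []
        for line in mlir_text.splitlines():
            if re.search(r"#map\d*\s*=", line):
                maps.append(line)            # affine map definitions
            elif re.search("arith.constant", line):
                constants.add(line)          # constants, de-duplicated per module
            elif not re.search("module", line):
                # keep the body, but give the entry point a unique name
                body.append(re.sub("forward", forward_name, line))
        return maps, constants, body

    # e.g. maps1, constants_1, f1 = classify_module_lines(first_module, "first_vicuna_forward")
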
@@ -281,7 +282,7 @@ class VicunaBase(SharkLLMBase):
             elif "global_seed" in line:
                 continue
             elif re.search("arith.constant", line):
-                constants.add(line)
+                constants_2.add(line)
             elif not re.search("module", line):
                 line = re.sub("forward", "second_vicuna_forward", line)
                 f2.append(line)
@@ -304,15 +305,21 @@ class VicunaBase(SharkLLMBase):
         module_end = "}"

         global_vars = []
         vnames = []
-        global_var_loading1 = []
-        global_var_loading2 = []
+        global_var_loading1 = dict()
+        global_var_loading2 = dict()

         print(f"[DEBUG] processing constants")
-        counter = 0
-        constants = list(constants)
+        # in both 1 and 2
+        constants = [(e , "") for e in list(constants_1 & constants_2)]
+        # only in 1
+        constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))])
+        # only in 2
+        constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))])
+        del constants_1, constants_2
+        gc.collect()

         while constants:
-            constant = constants.pop(0)
+            constant, vname_suf = constants.pop(0)
             vname, vbody = constant.split("=")
             vname = re.sub("%", "", vname)
             vname = vname.strip()
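
The replacement for constants = list(constants) above tags every constant line with where it came from, so name clashes between the two modules can be resolved later; global_var_loading1 and global_var_loading2 become dicts so they can double as ordered, de-duplicated collections of load lines. A sketch of just that partition (the set names follow the diff, the rest is illustrative):

    def partition_constants(constants_1, constants_2):
        """Tag each arith.constant line with the module(s) that define it:
        "" = shared by both, "_1" = only the first, "_2" = only the second."""
        constants = [(line, "") for line in constants_1 & constants_2]
        constants.extend((line, "_1") for line in constants_1 - constants_2)
        constants.extend((line, "_2") for line in constants_2 - constants_1)
        return constants

    # partition_constants({"a", "b"}, {"b", "c"})
    # -> [("b", ""), ("a", "_1"), ("c", "_2")]   (order within each group may vary)
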
@@ -322,35 +329,34 @@ class VicunaBase(SharkLLMBase):
             print(constant)
             vdtype = vbody.split(":")[-1].strip()
             fixed_vdtype = vdtype
             noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
-            if "c1_i64" in vname:
-                print(constant)
-                counter += 1
-                if counter == 2:
-                    counter = 0
-                    print("detected duplicate")
-                    continue
             vnames.append(vname)
             if "true" not in vname:
                 global_vars.append(
-                    f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
-                )
-                global_var_loading1.append(
-                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
-                )
-                global_var_loading2.append(
-                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
+                    f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}"
                 )
+                if vname_suf != "_2":
+                    global_var_loading1[
+                        f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
+                    ] = ""
+                if vname_suf != "_1":
+                    global_var_loading2[
+                        f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
+                    ] = ""
             else:
                 global_vars.append(
-                    f"ml_program.global private @{vname}({vbody}) : i1"
-                )
-                global_var_loading1.append(
-                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
-                )
-                global_var_loading2.append(
-                    f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
+                    f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1"
                 )
+                if vname_suf != "_2":
+                    global_var_loading1[
+                        f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
+                    ] = ""
+                if vname_suf != "_1":
+                    global_var_loading2[
+                        f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
+                    ] = ""

         del constants
         gc.collect()

         new_f1, new_f2 = [], []

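With the suffix available, each constant becomes one (possibly suffixed) ml_program.global declaration, and a matching global_load_const line is recorded only for the forward function(s) that actually contain that constant; keying the dicts by the load line keeps the entries ordered and de-duplicated, which is what lets the old c1_i64 duplicate counter go away. A sketch of that per-constant step, reusing the names from the diff (everything else is illustrative; the i1 branch for boolean constants follows the same pattern with a fixed i1 type):

    def emit_constant(vname, vbody, vdtype, vname_suf,
                      global_vars, global_var_loading1, global_var_loading2):
        """Sketch: declare one suffixed ml_program.global and record the
        matching global_load_const line for whichever forward uses it."""
        global_vars.append(
            f"ml_program.global private @{vname}{vname_suf}({vbody}) : {vdtype}"
        )
        load = f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {vdtype}"
        if vname_suf != "_2":            # shared, or only in the first module
            global_var_loading1[load] = ""
        if vname_suf != "_1":            # shared, or only in the second module
            global_var_loading2[load] = ""
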
@@ -358,7 +364,7 @@ class VicunaBase(SharkLLMBase):
         for line in f1:
             if "func.func" in line:
                 new_f1.append(line)
-                for global_var in global_var_loading1:
+                for global_var in global_var_loading1.keys():
                     new_f1.append(global_var)
             else:
                 new_f1.append(line)
@@ -367,7 +373,7 @@ class VicunaBase(SharkLLMBase):
         for line in f2:
             if "func.func" in line:
                 new_f2.append(line)
-                for global_var in global_var_loading2:
+                for global_var in global_var_loading2.keys():
                     if (
                         "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
                         in global_var
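
Both of the last two hunks splice the recorded load lines in right after the func.func signature of the corresponding forward function; the .keys() change only makes the dict iteration explicit. A sketch of that pass with hypothetical names (the extra c20_i64 special case visible at the end of the second hunk is left out here):

    def splice_global_loads(body_lines, global_var_loading):
        """Sketch: after the func.func signature line, insert every recorded
        global_load_const line; dict keys are already ordered and de-duplicated."""
        new_body = []
        for line in body_lines:
            new_body.append(line)
            if "func.func" in line:
                new_body.extend(global_var_loading.keys())
        return new_body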