Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
fixed a bug where designating device for vicuna didn't work
committed by Gaurav Shukla
parent fb865f1b99
commit 6f9f868fc0
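In short: the device chosen when a Vicuna pipeline is constructed now flows through compile() → get_sharded_model() → compile_to_vmfb(), instead of being overwritten with a hard-coded "cpu" inside the compile loop. A hypothetical usage sketch follows; the import path and the model-name argument are assumptions for illustration, not taken from this commit.

```python
# Hypothetical usage sketch: the import path and the model-name argument
# are assumptions; only the device/precision flow comes from this diff.
from apps.language_models.src.pipelines.vicuna_pipeline import Vicuna

vicuna = Vicuna("vicuna", device="vulkan", precision="fp16")
# __init__ now calls self.compile(device=device), so the requested
# device reaches each per-layer SharkInference module instead of being
# silently replaced with "cpu".
print(vicuna.generate("What is a shark?"))
```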
```diff
@@ -35,7 +35,7 @@ class Vicuna(SharkLLMBase):
         self.device = device
         self.precision = precision
         self.tokenizer = self.get_tokenizer()
-        self.shark_model = self.compile()
+        self.shark_model = self.compile(device=device)
 
     def get_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(
```
```diff
@@ -126,7 +126,7 @@ class Vicuna(SharkLLMBase):
         )
         return mlir_bytecode
 
-    def compile_to_vmfb(self, inputs, layers, is_first=True):
+    def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
         mlirs, modules = [], []
         for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"):
             if is_first:
```
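A side note on this signature change: the new device parameter sits before is_first, so any caller that had been passing is_first as a third positional argument would now silently bind that value to device. The call sites updated later in this diff all use keyword arguments, which avoids the hazard. A standalone illustration, not the repo's code:

```python
# Simplified stand-in for the new signature, showing the
# parameter-ordering hazard of inserting device before is_first.
def compile_to_vmfb_like(inputs, layers, device="cpu", is_first=True):
    return device, is_first

print(compile_to_vmfb_like([], [], is_first=False))  # ('cpu', False): keyword call is safe
print(compile_to_vmfb_like([], [], False))           # (False, True): a positional third
                                                     # argument now binds to device
```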
```diff
@@ -210,19 +210,6 @@ class Vicuna(SharkLLMBase):
 
             else:
                 module = self.write_in_dynamic_inputs1(str(module), 138)
-                if idx in [0, 5, 6, 7]:
-                    module_str = module
-                    module_str = module_str.splitlines()
-                    new_lines = []
-                    for line in module_str:
-                        if len(line) < 1000:
-                            new_lines.append(line)
-                        else:
-                            new_lines.append(line[:999])
-                    module_str = "\n".join(new_lines)
-                    f1_ = open(f"{idx}_1_test.mlir", "w+")
-                    f1_.write(module_str)
-                    f1_.close()
 
             bytecode = module.encode("UTF-8")
             bytecode_stream = BytesIO(bytecode)
```
```diff
@@ -236,20 +223,22 @@ class Vicuna(SharkLLMBase):
         for idx, layer in tqdm(enumerate(layers), desc="compiling modules"):
             if is_first:
                 vmfb_path = Path(f"{idx}_0.vmfb")
-                if idx < 25:
-                    device = "cpu"
-                else:
-                    device = "cpu"
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
                     module = SharkInference(
-                        None, device=device, mlir_dialect="tm_tensor"
+                        None,
+                        device=device,
+                        device_idx=idx % 1,
+                        mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
                     module = SharkInference(
-                        mlirs[idx], device=device, mlir_dialect="tm_tensor"
+                        mlirs[idx],
+                        device=device,
+                        device_idx=idx % 1,
+                        mlir_dialect="tm_tensor",
                     )
                     module.save_module(
                         module_name=f"{idx}_0",
```
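The block deleted above, if idx < 25: device = "cpu" else: device = "cpu", is the bug named in the commit message: both branches reassigned device, so the caller's choice was always discarded before SharkInference was constructed. A minimal sketch of the shadowing pattern, not the repo's code:

```python
# Minimal sketch of the removed bug: the parameter is clobbered on
# every path, so the caller's device can never take effect.
def buggy(idx, device="vulkan"):
    if idx < 25:
        device = "cpu"  # both branches assigned "cpu",
    else:
        device = "cpu"  # so the requested device was always lost
    return device

assert buggy(0, device="vulkan") == "cpu"
assert buggy(30, device="vulkan") == "cpu"
```

The same dead branches are removed from the is_first=False path in the next hunk.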
```diff
@@ -264,20 +253,22 @@ class Vicuna(SharkLLMBase):
                 modules.append(module)
             else:
                 vmfb_path = Path(f"{idx}_1.vmfb")
-                if idx < 25:
-                    device = "cpu"
-                else:
-                    device = "cpu"
                 if vmfb_path.exists():
                     # print(f"Found layer {idx} vmfb")
                     module = SharkInference(
-                        None, device=device, mlir_dialect="tm_tensor"
+                        None,
+                        device=device,
+                        device_idx=idx % 1,
+                        mlir_dialect="tm_tensor",
                     )
                     module.load_module(vmfb_path)
                 else:
                     print(f"Compiling layer {idx} vmfb")
                     module = SharkInference(
-                        mlirs[idx], device=device, mlir_dialect="tm_tensor"
+                        mlirs[idx],
+                        device=device,
+                        device_idx=idx % 1,
+                        mlir_dialect="tm_tensor",
                     )
                     module.save_module(
                         module_name=f"{idx}_1",
```
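The second (is_first=False) branch receives the identical fix. Note also that device_idx=idx % 1 always evaluates to 0, so every sharded layer currently targets the first device; a larger modulus would round-robin layers across devices, which is presumably the eventual intent. A small sketch of that placement logic, with n_devices as an illustrative parameter:

```python
# Round-robin layer placement. With n_devices=1 (the idx % 1 in this
# diff) every layer lands on device 0; larger values spread layers out.
def place_layers(num_layers: int, n_devices: int) -> list[int]:
    return [idx % n_devices for idx in range(num_layers)]

print(place_layers(6, 1))  # [0, 0, 0, 0, 0, 0]  (idx % 1, as committed)
print(place_layers(6, 2))  # [0, 1, 0, 1, 0, 1]
```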
```diff
@@ -293,7 +284,7 @@ class Vicuna(SharkLLMBase):
 
         return mlirs, modules
 
-    def get_sharded_model(self):
+    def get_sharded_model(self, device="cpu"):
         # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
         # please don't change it
         SAMPLE_INPUT_LEN = 137
```
```diff
@@ -316,7 +307,10 @@ class Vicuna(SharkLLMBase):
             FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
         ]
         _, modules0 = self.compile_to_vmfb(
-            placeholder_input0, layers0, is_first=True
+            placeholder_input0,
+            layers0,
+            is_first=True,
+            device=device,
         )
         shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0]
 
```
```diff
@@ -324,7 +318,7 @@ class Vicuna(SharkLLMBase):
             SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
         ]
         _, modules1 = self.compile_to_vmfb(
-            placeholder_input1, layers1, is_first=False
+            placeholder_input1, layers1, is_first=False, device=device
         )
         shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1]
 
```
```diff
@@ -333,8 +327,8 @@ class Vicuna(SharkLLMBase):
         )
         return sharded_model
 
-    def compile(self):
-        return self.get_sharded_model()
+    def compile(self, device="cpu"):
+        return self.get_sharded_model(device=device)
 
     def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
```