Switch the Torch-to-Linalg lowering pipelines from `canonicalize` to the separate `torch-unpack-torch-tensor` pass to unpack quantized weights (#1652)

This commit is contained in:
jinchen62
2023-07-15 04:54:53 -07:00
committed by GitHub
parent 5ec91143f5
commit e20cd71314

View File

@@ -577,7 +577,7 @@ class ShardedVicuna(SharkLLMBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module0,
"builtin.module(func.func(canonicalize),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -621,7 +621,7 @@ class ShardedVicuna(SharkLLMBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module1,
"builtin.module(func.func(canonicalize),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -969,7 +969,6 @@ class UnshardedVicuna(SharkLLMBase):
)
if not mlir_generated:
# Select a compilation prompt such that the resulting input_ids
# from the model's tokenizer has shape [1, 19]
if self.model_name == "codegen":
@@ -1019,7 +1018,7 @@ class UnshardedVicuna(SharkLLMBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(canonicalize),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -1187,7 +1186,7 @@ class UnshardedVicuna(SharkLLMBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(canonicalize),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else: