Pin iree versions and fix quant matmul flags. (#2055)

Restricts quantized matmul reassociation flags to CPU compiles of llama2 and pins IREE versions for SHARK 1.0
This commit is contained in:
Ean Garvey
2024-01-04 14:22:54 -06:00
committed by GitHub
parent 3887d83f5d
commit 6853a33728
5 changed files with 6 additions and 5 deletions

View File

@@ -2075,6 +2075,10 @@ class UnshardedVicuna(VicunaBase):
f"Compiling for device : {self.device}"
f"{'://' + str(self.device_id) if self.device_id is not None else ''}"
)
if "cpu" in self.device:
    # extend() must receive a list; passing a bare string would append
    # each character as a separate compiler arg.
    self.extra_args.extend([
        "--iree-llvmcpu-enable-quantized-matmul-reassociation",
        "--iree-global-opt-enable-quantized-matmul-reassociation",
    ])
shark_module = SharkInference(
mlir_module=combined_module,
device=self.device,

View File

@@ -177,8 +177,6 @@ def chat(
)
_extra_args = _extra_args + [
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]

View File

@@ -90,7 +90,7 @@ python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler==20231222.* iree-runtime==20231222.*
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"

View File

@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler==20231222.* iree-runtime==20231222.*
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

View File

@@ -43,7 +43,6 @@ def get_iree_device_args(device, extra_args=[]):
get_iree_cpu_args()
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args