Fix formatting and regex.

This commit is contained in:
Ean Garvey
2023-08-16 14:37:31 -05:00
parent ac01cfa5cc
commit e644fdf38a
3 changed files with 39 additions and 18 deletions

View File

@@ -1,6 +1,7 @@
import torch
from transformers import AutoModelForCausalLM
class FirstVicuna(torch.nn.Module):
def __init__(
self,
@@ -19,8 +20,10 @@ class FirstVicuna(torch.nn.Module):
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
@@ -65,7 +68,10 @@ class SecondVicuna(torch.nn.Module):
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(

View File

@@ -44,22 +44,31 @@ def get_iree_gpu_args():
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from hipinfo.
import os
import re
import subprocess
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"hipinfo", shell=True, text=True
),
).group(1)
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
@@ -69,13 +78,16 @@ def get_iree_rocm_args():
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
print("Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100.")
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
f"--iree-rocm-bc-dir={rocm_path}/amdgcn/bitcode",
f"--iree-rocm-bc-dir={bc_path}",
]

View File

@@ -583,7 +583,7 @@ def import_with_fx(
]
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
@@ -591,13 +591,16 @@ def import_with_fx(
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.export import (
replace_call_fn_target,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]