Add Langchain SHARK Compilation support for all paths

Vivek Khandelwal
2023-07-25 12:11:45 +00:00
parent 47f8a79c75
commit 927b662aa7


@@ -1,5 +1,7 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
@@ -20,8 +22,38 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
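# torch-mlir abstract interpretation functions (shape, dtype, value semantics)
# for the custom brevitas.matmul_rhs_group_quant op; they are bundled below into
# an extra_library that is passed to torch_mlir.compile on the int4/int8 path.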
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]
    else:
        raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    # output dtype is the dtype of the lhs float input
    lhs_rank, lhs_dtype = lhs_rank_dtype
    return lhs_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    return


brevitas_matmul_rhs_group_quant_library = [
    brevitas〇matmul_rhs_group_quant〡shape,
    brevitas〇matmul_rhs_group_quant〡dtype,
    brevitas〇matmul_rhs_group_quant〡has_value_semantics,
]
global_device = "cuda"
global_precision = "fp16"
@@ -31,6 +63,67 @@ if not args.run_docuchat_web:
tensor_device = "cpu" if args.device == "cpu" else "cuda"
class H2OGPTModel(torch.nn.Module):
def __init__(self, device, precision):
super().__init__()
torch_dtype = (
torch.float32
if precision == "fp32" or device == "cpu"
else torch.float16
)
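# Map the whole model either to CPU or to GPU 0, depending on the target device.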
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
model_kwargs = {
"local_files_only": False,
"torch_dtype": torch_dtype,
"resume_download": True,
"use_auth_token": False,
"trust_remote_code": True,
"offload_folder": "offline_folder",
"device_map": device_map,
}
config = AutoConfig.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
use_auth_token=False,
trust_remote_code=True,
offload_folder="offline_folder",
)
self.model = AutoModelForCausalLM.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
config=config,
**model_kwargs,
)
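# For int4/int8, quantize the transformer block weights with Brevitas:
# per-group (group size 128) asymmetric weight quantization with float scales.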
if precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model.transformer.h,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
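# Single-token forward pass: no past_key_values are fed in, and only the logits
# for the last position are returned (next-token prediction).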
def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return output.logits[:, -1, :]
class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -42,47 +135,48 @@ class H2OGPTSHARKModel(torch.nn.Module):
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None
need_to_compile = False
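# Artifact resolution order: reuse a local .vmfb, otherwise try a pre-compiled
# .vmfb from shark_tank, otherwise fetch (or generate) the MLIR and compile it here.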
if not vmfb_path.exists():
    need_to_compile = True
    # Downloading VMFB from shark_tank
    print("Trying to download pre-compiled vmfb from shark tank.")
    download_public_file(
        "gs://shark_tank/langchain/" + str(vmfb_path),
        vmfb_path.absolute(),
        single_file=True,
    )
    if vmfb_path.exists():
        print(
            "Pre-compiled vmfb downloaded from shark tank successfully."
        )
        need_to_compile = False

if need_to_compile:
    if not mlir_path.exists():
        print("Trying to download pre-generated mlir from shark tank.")
        # Downloading MLIR from shark_tank
        download_public_file(
            "gs://shark_tank/langchain/" + str(mlir_path),
            mlir_path.absolute(),
            single_file=True,
        )
    if mlir_path.exists():
        with open(mlir_path, "rb") as f:
            bytecode = f.read()
    else:
        # Generating the mlir
        bytecode = self.get_bytecode(tensor_device, args.precision)
    shark_module = SharkInference(
        mlir_module=bytecode,
        device=args.device,
        mlir_dialect="linalg",
    )
    print(f"[DEBUG] generating vmfb.")
    shark_module = _compile_module(
        shark_module, extended_model_name, []
    )
    print("Saved newly generated vmfb.")
if shark_module is None:
if vmfb_path.exists():
@@ -97,6 +191,72 @@ class H2OGPTSHARKModel(torch.nn.Module):
self.model = shark_module
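# Builds the eager H2OGPT model, traces it with import_with_fx, lowers the
# resulting TorchScript graph through torch-mlir, and returns MLIR bytecode.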
def get_bytecode(self, device, precision):
h2ogpt_model = H2OGPTModel(device, precision)
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 400)
).to(device=device)
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
device=device
)
h2ogptCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
h2ogpt_model,
h2ogptCompileInput,
is_f16=False,
precision=precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
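# The eager models are no longer needed once the TorchScript graph exists;
# free them to reduce peak memory before lowering.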
del h2ogpt_model
del self.src_model
print(f"[DEBUG] generating torch mlir")
if precision in ["int4", "int8"]:
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
)
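# Quantized path: lower only to the Torch dialect while keeping
# brevitas.matmul_rhs_group_quant legal, then run the unpack + linalg-on-tensors
# lowering pipeline explicitly.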
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
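# Full-precision path: lower directly to linalg-on-tensors.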
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
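# Serialize the lowered module to MLIR bytecode so it can be handed to SharkInference.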
print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
return bytecode
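# Runs one step through the compiled SHARK module and converts the returned
# numpy result back to a torch tensor.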
def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(