mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
Add Langchain SHARK Compilation support for all paths
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
@@ -20,8 +22,38 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
return [lhs[0], rhs[0]]
|
||||
else:
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
|
||||
@@ -31,6 +63,67 @@ if not args.run_docuchat_web:
|
||||
tensor_device = "cpu" if args.device == "cpu" else "cuda"
|
||||
|
||||
|
||||
class H2OGPTModel(torch.nn.Module):
|
||||
def __init__(self, device, precision):
|
||||
super().__init__()
|
||||
torch_dtype = (
|
||||
torch.float32
|
||||
if precision == "fp32" or device == "cpu"
|
||||
else torch.float16
|
||||
)
|
||||
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
|
||||
model_kwargs = {
|
||||
"local_files_only": False,
|
||||
"torch_dtype": torch_dtype,
|
||||
"resume_download": True,
|
||||
"use_auth_token": False,
|
||||
"trust_remote_code": True,
|
||||
"offload_folder": "offline_folder",
|
||||
"device_map": device_map,
|
||||
}
|
||||
config = AutoConfig.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder="offline_folder",
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
config=config,
|
||||
**model_kwargs,
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model.transformer.h,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=128,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return output.logits[:, -1, :]
|
||||
|
||||
|
||||
class H2OGPTSHARKModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -42,47 +135,48 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
|
||||
shark_module = None
|
||||
|
||||
need_to_compile = False
|
||||
if not vmfb_path.exists():
|
||||
if args.device in ["cuda", "cpu"] and args.precision in [
|
||||
"fp16",
|
||||
"fp32",
|
||||
]:
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Downloading vmfb from shark tank.")
|
||||
need_to_compile = True
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Trying to download pre-compiled vmfb from shark tank.")
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print(
|
||||
"Pre-compiled vmfb downloaded from shark tank successfully."
|
||||
)
|
||||
need_to_compile = False
|
||||
|
||||
if need_to_compile:
|
||||
if not mlir_path.exists():
|
||||
print("Trying to download pre-generated mlir from shark tank.")
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
# Generating the mlir
|
||||
bytecode = self.get_bytecode(tensor_device, args.precision)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
if vmfb_path.exists():
|
||||
@@ -97,6 +191,72 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
|
||||
self.model = shark_module
|
||||
|
||||
def get_bytecode(self, device, precision):
|
||||
h2ogpt_model = H2OGPTModel(device, precision)
|
||||
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 400)
|
||||
).to(device=device)
|
||||
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
|
||||
device=device
|
||||
)
|
||||
|
||||
h2ogptCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
h2ogpt_model,
|
||||
h2ogptCompileInput,
|
||||
is_f16=False,
|
||||
precision=precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del h2ogpt_model
|
||||
del self.src_model
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if precision in ["int4", "int8"]:
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = torch.from_numpy(
|
||||
self.model(
|
||||
|
||||
Reference in New Issue
Block a user