Fix for Langchain (#1694)

For CPU, remove max time stopping criteria
Fix web UI issue
This commit is contained in:
Vivek Khandelwal
2023-07-26 21:30:23 +05:30
committed by GitHub
parent 9d399eb988
commit 776a9c2293
4 changed files with 27 additions and 10 deletions

View File

@@ -2,4 +2,4 @@
count = 1
show-source = 1
select = E9,F63,F7,F82
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py

View File

@@ -30,7 +30,15 @@ from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -39,20 +47,30 @@ def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
raise ValueError("Input shapes not supported.")
def brevitasmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def brevitasmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics]
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
global_device = "cuda"
global_precision = "fp16"
@@ -541,6 +559,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
return next_token
def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(

View File

@@ -1,12 +1,10 @@
# for generate (gradio server) and finetune
datasets==2.13.0
sentencepiece==0.1.99
# gradio==3.37.0
huggingface_hub==0.16.4
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
@@ -21,7 +19,7 @@ bitsandbytes==0.39.0
accelerate==0.20.3
peft==0.4.0
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
# transformers==4.30.2
transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1
@@ -67,7 +65,7 @@ tiktoken==0.4.0
openai==0.27.8
# optional for chat with PDF
langchain==0.0.235
langchain==0.0.202
pypdf==3.12.2
# avoid textract, requires old six
#textract==1.6.5

View File

@@ -244,7 +244,7 @@ if __name__ == "__main__":
upscaler_status,
]
)
with gr.TabItem(label="DocuChat(Experimental)", id=9):
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()
# send to buttons