mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-06 20:43:50 -05:00
Fix for Langchain (#1694)
For CPU, remove max time stopping criteria Fix web UI issue
This commit is contained in:
2
.flake8
2
.flake8
@@ -2,4 +2,4 @@
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
|
||||
|
||||
@@ -30,7 +30,15 @@ from typing import List, Tuple
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -39,20 +47,30 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
@@ -541,6 +559,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
return next_token
|
||||
|
||||
def generate_token(self, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
self.truncated_input_ids = []
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# for generate (gradio server) and finetune
|
||||
datasets==2.13.0
|
||||
sentencepiece==0.1.99
|
||||
# gradio==3.37.0
|
||||
huggingface_hub==0.16.4
|
||||
appdirs==1.4.4
|
||||
fire==0.5.0
|
||||
docutils==0.20.1
|
||||
# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
|
||||
evaluate==0.4.0
|
||||
rouge_score==0.1.2
|
||||
sacrebleu==2.3.1
|
||||
@@ -21,7 +19,7 @@ bitsandbytes==0.39.0
|
||||
accelerate==0.20.3
|
||||
peft==0.4.0
|
||||
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
|
||||
# transformers==4.30.2
|
||||
transformers==4.30.2
|
||||
tokenizers==0.13.3
|
||||
APScheduler==3.10.1
|
||||
|
||||
@@ -67,7 +65,7 @@ tiktoken==0.4.0
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.235
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
@@ -244,7 +244,7 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=9):
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=10):
|
||||
h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
|
||||
Reference in New Issue
Block a user