Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 21:38:04 -05:00)
Fix for Langchain (#1694)

For CPU, remove the max-time stopping criteria. Fix the web UI issue.
.flake8 (2 changed lines)

@@ -2,4 +2,4 @@
 count = 1
 show-source = 1
 select = E9,F63,F7,F82
-exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py
+exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
apps/language_models/langchain/h2oai_pipeline.py

@@ -30,7 +30,15 @@ from typing import List, Tuple
 from brevitas_examples.llm.llm_quant.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

-def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+def brevitas〇matmul_rhs_group_quant〡shape(
+    lhs: List[int],
+    rhs: List[int],
+    rhs_scale: List[int],
+    rhs_zero_point: List[int],
+    rhs_bit_width: int,
+    rhs_group_size: int,
+) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
@@ -39,20 +47,30 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
         raise ValueError("Input shapes not supported.")


-def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+def brevitas〇matmul_rhs_group_quant〡dtype(
+    lhs_rank_dtype: Tuple[int, int],
+    rhs_rank_dtype: Tuple[int, int],
+    rhs_scale_rank_dtype: Tuple[int, int],
+    rhs_zero_point_rank_dtype: Tuple[int, int],
+    rhs_bit_width: int,
+    rhs_group_size: int,
+) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype


-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
+    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
+) -> None:
     return


 brevitas_matmul_rhs_group_quant_library = [
     brevitas〇matmul_rhs_group_quant〡shape,
     brevitas〇matmul_rhs_group_quant〡dtype,
-    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics,
+]

 global_device = "cuda"
 global_precision = "fp16"
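The 〇/〡 functions above appear to be torch-mlir shape/dtype library entries, and the change itself is purely cosmetic (black-style line wrapping). For readers who want the shape rule spelled out, here is a minimal plain-Python sketch, not taken from the repo; the return value of the 2-D branch falls outside the hunk and is an assumption.

# Minimal sketch (not from the repo) of the shape rule encoded above: a
# group-quantized matmul of a 3-D lhs [batch, seq, k] with a 2-D rhs [n, k]
# produces [batch, seq, n].
from typing import List


def matmul_rhs_group_quant_shape(lhs: List[int], rhs: List[int]) -> List[int]:
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]  # assumed; this branch is not shown in the hunk
    raise ValueError("Input shapes not supported.")


# e.g. a [1, 128, 4096] activation against an [11008, 4096] quantized weight
assert matmul_rhs_group_quant_shape([1, 128, 4096], [11008, 4096]) == [1, 128, 11008]
assert matmul_rhs_group_quant_shape([128, 4096], [11008, 4096]) == [128, 11008]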
@@ -541,6 +559,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
         return next_token

     def generate_token(self, **generate_kwargs):
+        del generate_kwargs["max_time"]
         self.truncated_input_ids = []

         generation_config_ = GenerationConfig.from_model_config(
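The generate_token change above drops the max_time entry from generate_kwargs before handing off to Hugging Face generation, which is the "remove max time stopping criteria" part of the commit message. A minimal sketch of why that matters on CPU, using the standard transformers API with gpt2 as a placeholder model (not code from this repo):

# Hedged sketch: gpt2 stands in for whatever model the pipeline actually serves.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# With max_time set, generate() installs a wall-clock stopping criterion and
# stops once the budget is spent -- on a slow CPU backend that can mean an
# answer truncated after only a token or two.
truncated = model.generate(**inputs, max_new_tokens=64, max_time=0.05)

# Without max_time (the effect of `del generate_kwargs["max_time"]` above),
# generation runs to max_new_tokens or EOS regardless of how slow the device is.
full = model.generate(**inputs, max_new_tokens=64)

print(len(truncated[0]), len(full[0]))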
@@ -1,12 +1,10 @@
 # for generate (gradio server) and finetune
 datasets==2.13.0
 sentencepiece==0.1.99
-# gradio==3.37.0
 huggingface_hub==0.16.4
 appdirs==1.4.4
 fire==0.5.0
 docutils==0.20.1
-# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
 evaluate==0.4.0
 rouge_score==0.1.2
 sacrebleu==2.3.1
@@ -21,7 +19,7 @@ bitsandbytes==0.39.0
 accelerate==0.20.3
 peft==0.4.0
 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
-# transformers==4.30.2
+transformers==4.30.2
 tokenizers==0.13.3
 APScheduler==3.10.1

@@ -67,7 +65,7 @@ tiktoken==0.4.0
 openai==0.27.8

 # optional for chat with PDF
-langchain==0.0.235
+langchain==0.0.202
 pypdf==3.12.2
 # avoid textract, requires old six
 #textract==1.6.5
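The dependency changes pin langchain back to 0.0.202 and turn transformers from a commented-out line into a hard pin at 4.30.2 (4.31.0+ breaks load_in_8bit, per the comment in the file). As a hedged sketch, a quick way to confirm an environment actually picked up these pins:

# Hedged sketch: check installed versions against the pins in the diff.
from importlib.metadata import PackageNotFoundError, version

pins = {"langchain": "0.0.202", "transformers": "4.30.2", "tokenizers": "0.13.3"}
for pkg, expected in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not installed"
    marker = "ok" if installed == expected else "MISMATCH"
    print(f"{pkg}: expected {expected}, found {installed} [{marker}]")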
@@ -244,7 +244,7 @@ if __name__ == "__main__":
                         upscaler_status,
                     ]
                 )
-            with gr.TabItem(label="DocuChat(Experimental)", id=9):
+            with gr.TabItem(label="DocuChat(Experimental)", id=10):
                 h2ogpt_web.render()

         # send to buttons
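The DocuChat tab's id moves from 9 to 10. In gradio 3.x, TabItem ids are what programmatic tab switches (for example "send to" buttons) select on, so they must be unique and consistent wherever they are referenced; a stale or duplicated id jumps the user to the wrong tab. A minimal sketch of that mechanism, under the assumption that this is the web UI issue the commit message refers to; the labels and ids below are illustrative, not the actual SHARK-Studio layout:

# Hedged sketch of gradio 3.x tab selection by id.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem(label="Upscaler", id=9):
            gr.Markdown("upscaler controls go here")
        with gr.TabItem(label="DocuChat(Experimental)", id=10):
            gr.Markdown("DocuChat controls go here")
    send_to_docuchat = gr.Button("Send to DocuChat")
    # The selected= value must match the TabItem id declared above; if two
    # tabs share an id (or the id is stale), the jump misfires.
    send_to_docuchat.click(
        lambda: gr.Tabs.update(selected=10), inputs=None, outputs=tabs
    )

if __name__ == "__main__":
    demo.launch()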