codegen API (#1655)

This commit is contained in:
Phaneesh Barwaria
2023-07-17 08:30:39 +05:30
committed by GitHub
parent a2a436eb0c
commit c471d17cca
4 changed files with 120 additions and 3 deletions

View File

@@ -1373,7 +1373,7 @@ class UnshardedVicuna(SharkLLMBase):
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
yield detok
# yield detok
res_tokens.append(token)
if cli:

View File

@@ -50,7 +50,9 @@ if __name__ == "__main__":
upscaler_api,
inpaint_api,
outpaint_api,
llm_chat_api,
)
from fastapi import FastAPI, APIRouter
import uvicorn
@@ -63,8 +65,19 @@ if __name__ == "__main__":
app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
app.add_api_route(
"/v1/chat/completions", llm_chat_api, methods=["post"]
)
app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
app.add_api_route("/completions", llm_chat_api, methods=["post"])
app.add_api_route(
"/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
)
app.include_router(APIRouter())
uvicorn.run(app, host="127.0.0.1", port=args.server_port)
uvicorn.run(app, host="0.0.0.0", port=args.server_port)
sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any

View File

@@ -74,7 +74,10 @@ from apps.stable_diffusion.web.ui.model_manager import (
modelmanager_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web
from apps.stable_diffusion.web.ui.stablelm_ui import stablelm_chat
from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,
outputgallery_tab_select,

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
def user(message, history):
@@ -73,6 +74,7 @@ def create_prompt(model_name, history):
return msg
# TODO: Make chat reusable for UI and API
def chat(curr_system_message, history, model, device, precision):
global sharded_model
global past_key_values
@@ -147,6 +149,105 @@ def chat(curr_system_message, history, model, device, precision):
return words_list
def llm_chat_api(InputData: dict):
    """OpenAI-compatible chat/completion endpoint backed by a Vicuna model.

    Accepts either the chat-completions payload (has a ``messages`` key) or
    the legacy completions payload (has a ``prompt`` key) and returns a dict
    shaped like the corresponding OpenAI API response.

    Args:
        InputData: Parsed JSON request body. Recognized keys: ``model``
            (defaults to ``"codegen"``), ``max_tokens``, and either
            ``messages`` (chat API) or ``prompt`` (legacy API).

    Returns:
        dict with ``id``, ``object``, ``created`` (Unix timestamp) and
        ``choices`` keys, mirroring the OpenAI response schema.
    """
    print(f"Input keys : {InputData.keys()}")
    # print(f"model : {InputData['model']}")
    # The presence of "messages" distinguishes the chat-completions API from
    # the legacy "completions" API, which sends a bare "prompt" string.
    is_chat_completion_api = "messages" in InputData

    # For Debugging input data from API
    # if is_chat_completion_api:
    #     print(f"message -> role : {InputData['messages'][0]['role']}")
    #     print(f"message -> content : {InputData['messages'][0]['content']}")
    # else:
    #     print(f"prompt : {InputData['prompt']}")
    #     print(f"max_tokens : {InputData['max_tokens']}")

    global vicuna_model
    model_name = InputData.get("model", "codegen")
    model_path = model_map[model_name]
    device = "cpu-task"
    precision = "fp16"

    # Default token budget: codegen gets a shorter budget than chat models.
    max_toks = InputData.get("max_tokens")
    if max_toks is None:
        max_toks = 128 if model_name == "codegen" else 512

    # make it working for codegen first
    from apps.language_models.scripts.vicuna import (
        UnshardedVicuna,
    )

    # vicuna_model == 0 is this module's "not loaded yet" sentinel — build the
    # model once and reuse it across requests.
    if vicuna_model == 0:
        # Normalize the device string to the names UnshardedVicuna expects.
        if "cuda" in device:
            device = "cuda"
        elif "sync" in device:
            device = "cpu-sync"
        elif "task" in device:
            device = "cpu-task"
        elif "vulkan" in device:
            device = "vulkan"
        else:
            print("unrecognized device")

        vicuna_model = UnshardedVicuna(
            model_name,
            hf_model_path=model_path,
            device=device,
            precision=precision,
            max_num_tokens=max_toks,
        )

    # TODO: add role dict for different models
    if is_chat_completion_api:
        # TODO: add functionality for multiple messages
        prompt = create_prompt(
            model_name, [(InputData["messages"][0]["content"], "")]
        )
    else:
        prompt = InputData["prompt"]
    print("prompt = ", prompt)

    # generate() yields cumulative output; drain the generator and keep only
    # the final, complete result.
    res = vicuna_model.generate(prompt)
    res_op = None
    for op in res:
        res_op = op

    if is_chat_completion_api:
        choices = [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": res_op,
                },
                "finish_reason": "stop",  # or length
            }
        ]
    else:
        choices = [
            {
                "text": res_op,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",  # or length
            }
        ]

    end_time = dt.now()
    return {
        "id": end_time.strftime("%Y%m%d%H%M%S%f"),
        "object": "chat.completion"
        if is_chat_completion_api
        else "text_completion",
        # The OpenAI schema defines "created" as a Unix timestamp in seconds;
        # the previous int("%Y%m%d%H%M%S%f") encoding produced a 20-digit
        # number that is not a valid timestamp.
        "created": int(end_time.timestamp()),
        "choices": choices,
    }
with gr.Blocks(title="Chatbot") as stablelm_chat:
with gr.Row():
model_choices = list(