(WIP): Studio2 app infra and SD API

This commit is contained in:
Ean Garvey
2023-12-08 00:45:34 -06:00
parent ebfcfec338
commit 98fea8b19c
7 changed files with 761 additions and 328 deletions

102
apps/shark_studio/api/sd.py Normal file
View File

@@ -0,0 +1,102 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch
sd_model_map = {
"sd15": {
"base_model_id": "runwayml/stable-diffusion-v1-5"
"clip": {
"initializer": clip.export_clip_model,
"max_tokens": 77,
}
"unet": {
"initializer": unet.export_unet_model,
"max_tokens": 512,
}
"vae_decode": {
"initializer": vae.export_vae_model,,
}
}
}
class SharkStableDiffusionPipeline:
def __init__(
self, model_name, , device=None, precision="fp32"
):
print(sd_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file
def generate_images(
self,
prompt,
):
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
history.append(token)
yield self.tokenizer.decode(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)

View File

@@ -0,0 +1,255 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from apps.shark_studio. import sd_samplers, postprocessing, errors, restart
from sdapi_v1 import shark_sd_api
from api.llm import chat_api
def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):
if not opts.api_enable_requests:
raise HTTPException(status_code=500, detail="Requests not allowed")
if opts.api_forbid_local_requests and not verify_url(encoding):
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
if image.mode == "RGBA":
image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
})
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
else:
raise HTTPException(status_code=500, detail="Invalid image format")
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
rich_available = False
try:
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
rich_available = True
except Exception:
pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
res: Response = await call_next(req)
duration = str(round(time.time() - ts, 4))
res.headers["X-Process-Time"] = duration
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get('http_version', '0.0'),
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
prot=req.scope.get('scheme', 'err'),
method=req.scope.get('method', 'err'),
endpoint=endpoint,
duration=duration,
))
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get('detail', ''),
"body": vars(e).get('body', ''),
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
class ApiCompat:
def __init__(self, queue_lock: Lock):
self.router = APIRouter()
self.app = FastAPI()
self.queue_lock = queue_lock
api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"])
self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"])
#self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"])
#self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
#self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
#self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
#self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
#self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
#self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
#self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
#self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
#self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
#self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
#self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
#self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
#self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
#self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
#self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
#self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
#self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
#self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
#self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
#self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
#self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
#self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
#self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
#self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
#self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
#self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
#self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
#self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
#self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
# chat APIs needed for compatibility with multiple extensions using OpenAI API
self.add_api_route(
"/v1/chat/completions", chat_api, methods=["post"]
)
self.add_api_route("/v1/completions", chat_api, methods=["post"])
self.add_api_route("/chat/completions", chat_api, methods=["post"])
self.add_api_route("/completions", chat_api, methods=["post"])
self.add_api_route(
"/v1/engines/codegen/completions", chat_api, methods=["post"]
)
if studio.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"])
self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path:str, endpoint, **kwargs):
if studio.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs
return self.app.add_api_route(path, endpoint, **kwargs)
def refresh_checkpoints(self):
with self.queue_lock:
studio_data.refresh_checkpoints()
def refresh_vae(self):
with self.queue_lock:
studio_data.refresh_vae_list()
def unloadapi(self):
unload_model_weights()
return {}
def reloadapi(self):
reload_model_weights()
return {}
def skip(self):
studio.state.skip()
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path)
def kill_studio(self):
restart.stop_program()
def restart_studio(self):
if restart.is_restartable():
restart.restart_program()
return Response(status_code=501)
def preprocess(self, args: dict):
try:
studio.state.begin(job="preprocess")
preprocess(**args)
studio.state.end()
return models.PreprocessResponse(info='preprocess complete')
except:
studio.state.end()
def stop_studio(request):
studio.state.server_command = "stop"
return Response("Stopping.")

View File

@@ -1,20 +1,54 @@
from multiprocessing import Process, freeze_support
import os
import time
import sys
import logging
from ui.chat import chat_element
from ui.sd import sd_element
from modules import timer, initialize
startup_timer = timer.startup_timer
startup_timer.record("launcher")
initialize.imports()
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def create_api(app):
from apps.shark_studio.api.compat import ApiCompat
from modules.call_queue import queue_lock
api = ApiCompat(app, queue_lock)
return api
def launch_app(address):
def api_only():
from fastapi import FastAPI
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
initialize.initialize()
app = FastAPI()
initialize.setup_middleware(app)
api = create_api(app)
#from modules import script_callbacks
#script_callbacks.before_ui_callback()
#script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
port=cmd_opts.port if cmd_opts.port else 8080,
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
)
def launch_webui(address):
from tkinter import Tk
import webview
@@ -33,63 +67,77 @@ def launch_app(address):
)
webview.start(private_mode=False, storage_path=os.getcwd())
def webui():
from apps.shark_studio.shared_cmd_options import cmd_opts
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
launch_api = cmd_opts.api
initialize.initialize()
from modules import shared, ui_tempdir, script_callbacks, ui, progress
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
# )
# from apps.shark_studio.web.api import sdapi
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
# from fastapi import FastAPI, APIRouter
# from fastapi.middleware.cors import CORSMiddleware
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
# # init global sd pipeline and config
# global_obj._init()
#
# api = FastAPI()
# api.mount("/sdapi/", sdapi)
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# api.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# api.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# api.add_api_route("/completions", llm_chat_api, methods=["post"])
# api.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# api.include_router(APIRouter())
#
# # deal with CORS requests if CORS accept origins are set
# if args.api_accept_origin:
# print(
# f"API Configured for CORS. Accepting origins: { args.api_accept_origin }"
# )
# api.add_middleware(
# CORSMiddleware,
# allow_origins=args.api_accept_origin,
# allow_methods=["GET", "POST"],
# allow_headers=["*"],
# )
# else:
# print("API not configured for CORS")
#
# uvicorn.run(api, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
from apps.shark_studio.web.initializers import (
config_gradio_tmp_imgs_folder,
create_custom_models_folders,
)
# config_gradio_tmp_imgs_folder()
config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
@@ -98,74 +146,10 @@ if __name__ == "__main__":
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
from apps.shark_studio.web.ui import load_ui_from_script
# init global sd pipeline and config
# global_obj._init()
studio.state._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
@@ -176,18 +160,6 @@ if __name__ == "__main__":
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
@@ -200,7 +172,7 @@ if __name__ == "__main__":
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
) as sd_web:
) as studio_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
@@ -211,216 +183,29 @@ if __name__ == "__main__":
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
with gr.TabItem(label="Stable Diffusion", id=0):
sd_element.render()
#with gr.TabItem(label="Output Gallery", id=1):
with gr.TabItem(label="Chat Bot", id=2):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
studio_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
studio_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)
if __name__ == "__main__":
from apps.shark_studio.shared_cmd_options import cmd_opts
if cmd_opts.nowebui:
api_only()
else:
webui()

View File

@@ -0,0 +1,145 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread
from modules.timer import startup_timer
def imports():
import torch # noqa: F401
startup_timer.record("import torch")
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="torch")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
import gradio # noqa: F401
startup_timer.record("import gradio")
from apps.shark_studio.modules import shared_init
shared_init.initialize()
startup_timer.record("initialize shared")
from apps.shark_studio.modules import processing, gradio_extensons, ui # noqa: F401
startup_timer.record("other imports")
def initialize():
configure_sigint_handler()
configure_opts_onchange()
from apps.shark_studio.modules import modelloader
modelloader.cleanup_models()
from apps.shark_studio.modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
#from apps.shark_studio.modules.shared_cmd_options import cmd_opts
#from apps.shark_studio.modules import codeformer_model
#warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
#codeformer_model.setup_model(cmd_opts.codeformer_models_path)
#startup_timer.record("setup codeformer")
#from apps.shark_studio.modules import gfpgan_model
#gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
#startup_timer.record("setup gfpgan")
initialize_rest(reload_script_modules=False)
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
print("\n".join(code))
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
dumpstacks()
os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler)
def dumpstacks():
import threading
import traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
print("\n".join(code))
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
from apps.shark_studio.modules.shared_cmd_options import cmd_opts
from apps.shark_studio.modules import sd_samplers
sd_samplers.set_samplers()
startup_timer.record("set samplers")
restore_config_state_file()
startup_timer.record("restore config state file")
from apps.shark_studio.modules import sd_models
sd_models.list_models()
startup_timer.record("list SD models")
with startup_timer.subcategory("load scripts"):
scripts.load_scripts()
if reload_script_modules:
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")
from apps.shark_studio.modules import sd_vae
sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
# from apps.shark_studio.modules import textual_inversion
# textual_inversion.textual_inversion.list_textual_inversion_templates()
# startup_timer.record("refresh textual inversion templates")
from apps.shark_studio.modules import sd_unet
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.
"""
shared.sd_model # noqa: B018
Thread(target=load_model).start()

View File

@@ -0,0 +1,53 @@
import sys
import gradio as gr
from modules import shared_cmd_options, shared_gradio, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules import util
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
config_filename = cmd_opts.ui_settings_file
demo = None
device = None
weight_load_location = None
state = None
prompt_styles = None
options_templates = None
opts = None
restricted_opts = None
sd_model: sd_models_types.WebuiSdModel = None
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
tab_names = []
sd_upscalers = []
clip_model = None
progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
total_tqdm = None
mem_mon = None
reload_gradio_theme = shared_gradio.reload_gradio_theme
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
refresh_checkpoints = shared_items.refresh_checkpoints
list_samplers = shared_items.list_samplers

View File

@@ -0,0 +1,91 @@
import time
import argparse
class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategroy = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
self.timer.subcategory_level -= 1
self.timer.record(self.category, disable_log=True)
class Timer:
def __init__(self, print_log=False):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ''
self.print_log = print_log
self.subcategory_level = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0
self.records[category] += amount
def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time
if self.print_log and not disable_log:
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
def subcategory(self, name):
self.elapsed()
subcat = TimerSubcategory(self, name)
return subcat
def summary(self):
res = f"{self.total:.1f}s"
additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
if not additions:
return res
res += " ("
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
res += ")"
return res
def dump(self):
return {'total': self.total, 'records': self.records}
def reset(self):
self.__init__()
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
args = parser.parse_known_args()[0]
startup_timer = Timer(print_log=args.log_startup)
startup_record = None

View File

@@ -240,9 +240,11 @@ with gr.Blocks(title="Chat") as chat_element:
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button("View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)