Small cleanups.

- Quenched tokenizer warnings during model initialization.
- Changed "batch" to "iterations" for generating multiple images in
  order to conserve vram.
- Updated README.
- Moved static folder from under scripts to top level. Can store other
  static content there in future.
- Added screenshot of web server in action (to static folder).
This commit is contained in:
Lincoln Stein
2022-08-25 15:03:40 -04:00
parent 79add5f0b6
commit 2ada3288e7
4 changed files with 42 additions and 12 deletions

View File

@@ -1,11 +1,20 @@
import json
import base64
import os
from pytorch_lightning import logging
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
print("Loading model...")
from ldm.simplet2i import T2I
model = T2I()
model = T2I(sampler_name='k_lms')
# to get rid of annoying warning messages from pytorch
import transformers
transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
print("Initializing model, be patient...")
model.load_model()
class DreamServer(BaseHTTPRequestHandler):
def do_GET(self):
@@ -13,7 +22,7 @@ class DreamServer(BaseHTTPRequestHandler):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./scripts/static/index.html", "rb") as content:
with open("./static/index.html", "rb") as content:
self.wfile.write(content.read())
elif os.path.exists("." + self.path):
self.send_response(200)
@@ -33,7 +42,7 @@ class DreamServer(BaseHTTPRequestHandler):
post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt']
initimg = post_data['initimg']
batch = int(post_data['batch'])
iterations = int(post_data['iterations'])
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
@@ -46,7 +55,7 @@ class DreamServer(BaseHTTPRequestHandler):
if initimg is None:
# Run txt2img
outputs = model.txt2img(prompt,
batch_size = batch,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
@@ -61,7 +70,7 @@ class DreamServer(BaseHTTPRequestHandler):
# Run img2img
outputs = model.img2img(prompt,
init_img = "./img2img-tmp.png",
batch_size = batch,
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps)
@@ -77,7 +86,7 @@ class DreamServer(BaseHTTPRequestHandler):
if __name__ == "__main__":
dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer)
print("Started Stable Diffusion dream server!")
print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *")
try:
dream_server.serve_forever()