mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
234 lines
8.0 KiB
Python
234 lines
8.0 KiB
Python
import gradio as gr
|
|
import json
|
|
import jsonlines
|
|
import os
|
|
from args import args
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
from utils import get_datasets
|
|
|
|
|
|
shark_root = Path(__file__).parent.parent
|
|
demo_css = shark_root.joinpath("web/demo.css").resolve()
|
|
nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/amd-logo.jpg")
|
|
|
|
|
|
with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
|
with gr.Row(elem_id="ui_title"):
|
|
nod_logo = Image.open(nodlogo_loc)
|
|
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
|
gr.Image(
|
|
value=nod_logo,
|
|
show_label=False,
|
|
interactive=False,
|
|
show_download_button=False,
|
|
elem_id="top_logo",
|
|
width=150,
|
|
height=100,
|
|
)
|
|
|
|
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
|
|
prompt_data = dict()
|
|
|
|
with gr.Row(elem_id="ui_body"):
|
|
# TODO: add multiselect dataset, there is a gradio version conflict
|
|
dataset = gr.Dropdown(label="Dataset", choices=datasets)
|
|
image_name = gr.Dropdown(label="Image", choices=[])
|
|
|
|
with gr.Row(elem_id="ui_body"):
|
|
# TODO: add ability to search image by typing
|
|
with gr.Column(scale=1, min_width=600):
|
|
image = gr.Image(type="filepath", height=512)
|
|
|
|
with gr.Column(scale=1, min_width=600):
|
|
prompts = gr.Dropdown(
|
|
label="Prompts",
|
|
choices=[],
|
|
)
|
|
prompt = gr.Textbox(
|
|
label="Editor",
|
|
lines=3,
|
|
)
|
|
with gr.Row():
|
|
save = gr.Button("Save")
|
|
delete = gr.Button("Delete")
|
|
with gr.Row():
|
|
back_image = gr.Button("Back")
|
|
next_image = gr.Button("Next")
|
|
finish = gr.Button("Finish")
|
|
|
|
def filter_datasets(dataset):
|
|
if dataset is None:
|
|
return gr.Dropdown.update(value=None, choices=[])
|
|
|
|
# create the dataset dir if doesn't exist and download prompt file
|
|
dataset_path = str(shark_root) + "/dataset/" + dataset
|
|
if not os.path.exists(dataset_path):
|
|
os.mkdir(dataset_path)
|
|
|
|
# read prompt jsonlines file
|
|
prompt_data.clear()
|
|
if dataset in ds_w_prompts:
|
|
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
|
|
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
|
|
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
|
|
for line in reader.iter(type=dict, skip_invalid=True):
|
|
prompt_data[line["file_name"]] = (
|
|
[line["text"]] if type(line["text"]) is str else line["text"]
|
|
)
|
|
|
|
return gr.Dropdown.update(choices=images[dataset])
|
|
|
|
dataset.change(fn=filter_datasets, inputs=dataset, outputs=image_name)
|
|
|
|
def display_image(dataset, image_name):
|
|
if dataset is None or image_name is None:
|
|
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
|
|
|
|
# download and load the image
|
|
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
|
|
img_sub_path = "/".join(image_name.split("/")[:-1])
|
|
img_dst_path = (
|
|
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
|
|
)
|
|
if not os.path.exists(img_dst_path):
|
|
os.mkdir(img_dst_path)
|
|
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
|
|
img = Image.open(img_dst_path + image_name.split("/")[-1])
|
|
|
|
if image_name not in prompt_data.keys():
|
|
prompt_data[image_name] = []
|
|
prompt_choices = ["Add new"]
|
|
prompt_choices += prompt_data[image_name]
|
|
return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices)
|
|
|
|
image_name.change(
|
|
fn=display_image,
|
|
inputs=[dataset, image_name],
|
|
outputs=[image, prompts],
|
|
)
|
|
|
|
def edit_prompt(prompts):
|
|
if prompts == "Add new":
|
|
return gr.Textbox.update(value=None)
|
|
|
|
return gr.Textbox.update(value=prompts)
|
|
|
|
prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt)
|
|
|
|
def save_prompt(dataset, image_name, prompts, prompt):
|
|
if dataset is None or image_name is None or prompts is None or prompt is None:
|
|
return
|
|
|
|
if prompts == "Add new":
|
|
prompt_data[image_name].append(prompt)
|
|
else:
|
|
idx = prompt_data[image_name].index(prompts)
|
|
prompt_data[image_name][idx] = prompt
|
|
|
|
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
|
|
# write prompt jsonlines file
|
|
with open(prompt_path, "w") as f:
|
|
for key, value in prompt_data.items():
|
|
if not value:
|
|
continue
|
|
v = value if len(value) > 1 else value[0]
|
|
f.write(json.dumps({"file_name": key, "text": v}))
|
|
f.write("\n")
|
|
|
|
prompt_choices = ["Add new"]
|
|
prompt_choices += prompt_data[image_name]
|
|
return gr.Dropdown.update(choices=prompt_choices, value=None)
|
|
|
|
save.click(
|
|
fn=save_prompt,
|
|
inputs=[dataset, image_name, prompts, prompt],
|
|
outputs=prompts,
|
|
)
|
|
|
|
def delete_prompt(dataset, image_name, prompts):
|
|
if dataset is None or image_name is None or prompts is None:
|
|
return
|
|
if prompts == "Add new":
|
|
return
|
|
|
|
prompt_data[image_name].remove(prompts)
|
|
prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl"
|
|
# write prompt jsonlines file
|
|
with open(prompt_path, "w") as f:
|
|
for key, value in prompt_data.items():
|
|
if not value:
|
|
continue
|
|
v = value if len(value) > 1 else value[0]
|
|
f.write(json.dumps({"file_name": key, "text": v}))
|
|
f.write("\n")
|
|
|
|
prompt_choices = ["Add new"]
|
|
prompt_choices += prompt_data[image_name]
|
|
return gr.Dropdown.update(choices=prompt_choices, value=None)
|
|
|
|
delete.click(
|
|
fn=delete_prompt,
|
|
inputs=[dataset, image_name, prompts],
|
|
outputs=prompts,
|
|
)
|
|
|
|
def get_back_image(dataset, image_name):
|
|
if dataset is None or image_name is None:
|
|
return
|
|
|
|
# remove local image
|
|
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
|
|
os.system(f'rm "{img_path}"')
|
|
# get the index for the back image
|
|
idx = images[dataset].index(image_name)
|
|
if idx == 0:
|
|
return gr.Dropdown.update(value=None)
|
|
|
|
return gr.Dropdown.update(value=images[dataset][idx - 1])
|
|
|
|
back_image.click(
|
|
fn=get_back_image, inputs=[dataset, image_name], outputs=image_name
|
|
)
|
|
|
|
def get_next_image(dataset, image_name):
|
|
if dataset is None or image_name is None:
|
|
return
|
|
|
|
# remove local image
|
|
img_path = str(shark_root) + "/dataset/" + dataset + "/" + image_name
|
|
os.system(f'rm "{img_path}"')
|
|
# get the index for the next image
|
|
idx = images[dataset].index(image_name)
|
|
if idx == len(images[dataset]) - 1:
|
|
return gr.Dropdown.update(value=None)
|
|
|
|
return gr.Dropdown.update(value=images[dataset][idx + 1])
|
|
|
|
next_image.click(
|
|
fn=get_next_image, inputs=[dataset, image_name], outputs=image_name
|
|
)
|
|
|
|
def finish_annotation(dataset):
|
|
if dataset is None:
|
|
return
|
|
|
|
# upload prompt and remove local data
|
|
dataset_path = str(shark_root) + "/dataset/" + dataset
|
|
dataset_gs_path = args.gs_url + "/" + dataset + "/"
|
|
os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"')
|
|
os.system(f'rm -rf "{dataset_path}"')
|
|
|
|
return gr.Dropdown.update(value=None)
|
|
|
|
finish.click(fn=finish_annotation, inputs=dataset, outputs=dataset)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
shark_web.launch(
|
|
share=args.share,
|
|
inbrowser=True,
|
|
server_name="0.0.0.0",
|
|
server_port=args.server_port,
|
|
)
|