Address TODOs for dataset annotator (#872)

- add argument parsing (args.py); pass gs_url via a command-line flag
- add support for datasets that have no existing prompts
This commit is contained in:
jinchen62
2023-01-25 09:28:23 -08:00
committed by GitHub
parent aafe7c4701
commit c3a641f0ab
4 changed files with 72 additions and 31 deletions

View File

@@ -16,7 +16,7 @@ pip install -r requirements.txt
python annotation_tool.py
```
<img width="1308" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214191759-24cc5fe6-cd53-4099-87f6-707068f8888d.png">
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
* Select a dataset from `Dataset` dropdown list
* Select an image from `Image` dropdown list

View File

@@ -2,15 +2,12 @@ import gradio as gr
import json
import jsonlines
import os
from args import args
from pathlib import Path
from PIL import Image
from utils import get_datasets
# TODO: pass gs_url as a command line flag
# see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
gs_url = "gs://shark-datasets/portraits"
shark_root = Path(__file__).parent.parent
demo_css = shark_root.joinpath("web/demo.css").resolve()
nodlogo_loc = shark_root.joinpath(
@@ -30,15 +27,15 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
elem_id="top_logo",
).style(width=150, height=100)
datasets, images = get_datasets(gs_url)
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
prompt_data = dict()
with gr.Row(elem_id="ui_body"):
# TODO: add multiselect dataset
# TODO: add multiselect dataset, there is a gradio version conflict
dataset = gr.Dropdown(label="Dataset", choices=datasets)
image_name = gr.Dropdown(label="Image", choices=[])
with gr.Row(elem_id="ui_body", visible=True):
with gr.Row(elem_id="ui_body"):
# TODO: add ability to search image by typing
with gr.Column(scale=1, min_width=600):
image = gr.Image(type="filepath").style(height=512)
@@ -61,27 +58,26 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
finish = gr.Button("Finish")
def filter_datasets(dataset):
# TODO: execute finish process when switching dataset
if dataset is None:
return gr.Dropdown.update(value=None, choices=[])
# create the dataset dir if doesn't exist and download prompt file
dataset_path = str(shark_root) + "/dataset/" + dataset
# TODO: check if metadata.jsonl exists
prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
if not os.path.exists(dataset_path):
os.mkdir(dataset_path)
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
# read prompt jsonlines file
prompt_data.clear()
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
if dataset in ds_w_prompts:
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
for line in reader.iter(type=dict, skip_invalid=True):
prompt_data[line["file_name"]] = (
[line["text"]]
if type(line["text"]) is str
else line["text"]
)
return gr.Dropdown.update(choices=images[dataset])
@@ -92,8 +88,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
# download and load the image
# TODO: remove previous image if change image from dropdown
img_gs_path = gs_url + "/" + dataset + "/" + image_name
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
img_sub_path = "/".join(image_name.split("/")[:-1])
img_dst_path = (
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
@@ -103,6 +98,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
img = Image.open(img_dst_path + image_name.split("/")[-1])
if image_name not in prompt_data.keys():
prompt_data[image_name] = []
prompt_choices = ["Add new"]
prompt_choices += prompt_data[image_name]
return gr.Image.update(value=img), gr.Dropdown.update(
@@ -144,6 +141,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
@@ -171,6 +170,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# write prompt jsonlines file
with open(prompt_path, "w") as f:
for key, value in prompt_data.items():
if not value:
continue
v = value if len(value) > 1 else value[0]
f.write(json.dumps({"file_name": key, "text": v}))
f.write("\n")
@@ -227,7 +228,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
# upload prompt and remove local data
dataset_path = str(shark_root) + "/dataset/" + dataset
dataset_gs_path = gs_url + "/" + dataset + "/"
dataset_gs_path = args.gs_url + "/" + dataset + "/"
os.system(
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
)
@@ -240,8 +241,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
if __name__ == "__main__":
shark_web.launch(
share=False,
share=args.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=8080,
server_port=args.server_port,
)

34
dataset/args.py Normal file
View File

@@ -0,0 +1,34 @@
"""Command-line flags for the dataset annotation tool.

Importing this module parses the process command line immediately and
exposes the result as the module-level ``args`` namespace.
"""
import argparse

# The module docstring is reused as the --help description; without a
# docstring ``__doc__`` is None and the help output has no description.
# ArgumentDefaultsHelpFormatter appends each flag's default to its help text.
p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

##############################################################################
### Dataset Annotator flags
##############################################################################

# Required: there is no default bucket, so the caller must always name one.
p.add_argument(
    "--gs_url",
    type=str,
    required=True,
    help="URL to datasets in GS bucket",
)

# BooleanOptionalAction (Python 3.9+) accepts both --share and --no-share.
p.add_argument(
    "--share",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="flag for generating a public URL",
)

p.add_argument(
    "--server_port",
    type=int,
    default=8080,
    help="flag for setting server port",
)

##############################################################################

# Parsed at import time so that ``from args import args`` yields a ready-made
# namespace; argparse exits the process if --gs_url is missing.
args = p.parse_args()

View File

@@ -4,6 +4,7 @@ from google.cloud import storage
def get_datasets(gs_url):
datasets = set()
images = dict()
ds_w_prompts = []
storage_client = storage.Client()
bucket_name = gs_url.split("/")[2]
@@ -12,12 +13,17 @@ def get_datasets(gs_url):
for blob in blobs:
dataset_name = blob.name.split("/")[1]
if dataset_name == "":
continue
datasets.add(dataset_name)
file_sub_path = "/".join(blob.name.split("/")[2:])
# check if image or jsonl
if "/" in file_sub_path:
if dataset_name not in images.keys():
images[dataset_name] = []
images[dataset_name] += [file_sub_path]
if dataset_name not in images.keys():
images[dataset_name] = []
return list(datasets), images
# check if image or jsonl
file_sub_path = "/".join(blob.name.split("/")[2:])
if "/" in file_sub_path:
images[dataset_name] += [file_sub_path]
elif "metadata.jsonl" in file_sub_path:
ds_w_prompts.append(dataset_name)
return list(datasets), images, ds_w_prompts