mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Address TODOs for dataset annotator (#872)
- add args usage and pass gs_url via a command-line flag; - add support for datasets with no existing prompts
This commit is contained in:
@@ -16,7 +16,7 @@ pip install -r requirements.txt
|
||||
python annotation_tool.py
|
||||
```
|
||||
|
||||
<img width="1308" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214191759-24cc5fe6-cd53-4099-87f6-707068f8888d.png">
|
||||
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
|
||||
|
||||
* Select a dataset from `Dataset` dropdown list
|
||||
* Select an image from `Image` dropdown list
|
||||
|
||||
@@ -2,15 +2,12 @@ import gradio as gr
|
||||
import json
|
||||
import jsonlines
|
||||
import os
|
||||
from args import args
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from utils import get_datasets
|
||||
|
||||
|
||||
# TODO: pass gs_url as a command line flag
|
||||
# see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
|
||||
gs_url = "gs://shark-datasets/portraits"
|
||||
|
||||
shark_root = Path(__file__).parent.parent
|
||||
demo_css = shark_root.joinpath("web/demo.css").resolve()
|
||||
nodlogo_loc = shark_root.joinpath(
|
||||
@@ -30,15 +27,15 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
elem_id="top_logo",
|
||||
).style(width=150, height=100)
|
||||
|
||||
datasets, images = get_datasets(gs_url)
|
||||
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
|
||||
prompt_data = dict()
|
||||
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
# TODO: add multiselect dataset
|
||||
# TODO: add multiselect dataset; there is a gradio version conflict
|
||||
dataset = gr.Dropdown(label="Dataset", choices=datasets)
|
||||
image_name = gr.Dropdown(label="Image", choices=[])
|
||||
|
||||
with gr.Row(elem_id="ui_body", visible=True):
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
# TODO: add ability to search image by typing
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
image = gr.Image(type="filepath").style(height=512)
|
||||
@@ -61,27 +58,26 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
finish = gr.Button("Finish")
|
||||
|
||||
def filter_datasets(dataset):
|
||||
# TODO: execute finish process when switching dataset
|
||||
if dataset is None:
|
||||
return gr.Dropdown.update(value=None, choices=[])
|
||||
|
||||
# create the dataset dir if it doesn't exist and download the prompt file
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
# TODO: check if metadata.jsonl exists
|
||||
prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
|
||||
if not os.path.exists(dataset_path):
|
||||
os.mkdir(dataset_path)
|
||||
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
|
||||
|
||||
# read prompt jsonlines file
|
||||
prompt_data.clear()
|
||||
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
|
||||
for line in reader.iter(type=dict, skip_invalid=True):
|
||||
prompt_data[line["file_name"]] = (
|
||||
[line["text"]]
|
||||
if type(line["text"]) is str
|
||||
else line["text"]
|
||||
)
|
||||
if dataset in ds_w_prompts:
|
||||
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
|
||||
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
|
||||
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
|
||||
for line in reader.iter(type=dict, skip_invalid=True):
|
||||
prompt_data[line["file_name"]] = (
|
||||
[line["text"]]
|
||||
if type(line["text"]) is str
|
||||
else line["text"]
|
||||
)
|
||||
|
||||
return gr.Dropdown.update(choices=images[dataset])
|
||||
|
||||
@@ -92,8 +88,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
|
||||
|
||||
# download and load the image
|
||||
# TODO: remove previous image if change image from dropdown
|
||||
img_gs_path = gs_url + "/" + dataset + "/" + image_name
|
||||
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
|
||||
img_sub_path = "/".join(image_name.split("/")[:-1])
|
||||
img_dst_path = (
|
||||
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
|
||||
@@ -103,6 +98,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
|
||||
img = Image.open(img_dst_path + image_name.split("/")[-1])
|
||||
|
||||
if image_name not in prompt_data.keys():
|
||||
prompt_data[image_name] = []
|
||||
prompt_choices = ["Add new"]
|
||||
prompt_choices += prompt_data[image_name]
|
||||
return gr.Image.update(value=img), gr.Dropdown.update(
|
||||
@@ -144,6 +141,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
# write prompt jsonlines file
|
||||
with open(prompt_path, "w") as f:
|
||||
for key, value in prompt_data.items():
|
||||
if not value:
|
||||
continue
|
||||
v = value if len(value) > 1 else value[0]
|
||||
f.write(json.dumps({"file_name": key, "text": v}))
|
||||
f.write("\n")
|
||||
@@ -171,6 +170,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
# write prompt jsonlines file
|
||||
with open(prompt_path, "w") as f:
|
||||
for key, value in prompt_data.items():
|
||||
if not value:
|
||||
continue
|
||||
v = value if len(value) > 1 else value[0]
|
||||
f.write(json.dumps({"file_name": key, "text": v}))
|
||||
f.write("\n")
|
||||
@@ -227,7 +228,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
|
||||
# upload prompt and remove local data
|
||||
dataset_path = str(shark_root) + "/dataset/" + dataset
|
||||
dataset_gs_path = gs_url + "/" + dataset + "/"
|
||||
dataset_gs_path = args.gs_url + "/" + dataset + "/"
|
||||
os.system(
|
||||
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
|
||||
)
|
||||
@@ -240,8 +241,8 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
|
||||
if __name__ == "__main__":
|
||||
shark_web.launch(
|
||||
share=False,
|
||||
share=args.share,
|
||||
inbrowser=True,
|
||||
server_name="0.0.0.0",
|
||||
server_port=8080,
|
||||
server_port=args.server_port,
|
||||
)
|
||||
|
||||
34
dataset/args.py
Normal file
34
dataset/args.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Dataset Annotator flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--gs_url",
|
||||
type=str,
|
||||
required=True,
|
||||
help="URL to datasets in GS bucket",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for generating a public URL",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="flag for setting server port",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
|
||||
args = p.parse_args()
|
||||
@@ -4,6 +4,7 @@ from google.cloud import storage
|
||||
def get_datasets(gs_url):
|
||||
datasets = set()
|
||||
images = dict()
|
||||
ds_w_prompts = []
|
||||
|
||||
storage_client = storage.Client()
|
||||
bucket_name = gs_url.split("/")[2]
|
||||
@@ -12,12 +13,17 @@ def get_datasets(gs_url):
|
||||
|
||||
for blob in blobs:
|
||||
dataset_name = blob.name.split("/")[1]
|
||||
if dataset_name == "":
|
||||
continue
|
||||
datasets.add(dataset_name)
|
||||
file_sub_path = "/".join(blob.name.split("/")[2:])
|
||||
# check if image or jsonl
|
||||
if "/" in file_sub_path:
|
||||
if dataset_name not in images.keys():
|
||||
images[dataset_name] = []
|
||||
images[dataset_name] += [file_sub_path]
|
||||
if dataset_name not in images.keys():
|
||||
images[dataset_name] = []
|
||||
|
||||
return list(datasets), images
|
||||
# check if image or jsonl
|
||||
file_sub_path = "/".join(blob.name.split("/")[2:])
|
||||
if "/" in file_sub_path:
|
||||
images[dataset_name] += [file_sub_path]
|
||||
elif "metadata.jsonl" in file_sub_path:
|
||||
ds_w_prompts.append(dataset_name)
|
||||
|
||||
return list(datasets), images, ds_w_prompts
|
||||
|
||||
Reference in New Issue
Block a user