fix img_ids repeating its values

This commit is contained in:
Francis Lata
2025-02-02 19:21:46 +00:00
parent 17ae62d741
commit 932cf4b7f2

View File

@@ -393,6 +393,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
ann = dataset.loadAnns(dataset.getAnnIds(img_id:=img["id"]))
tgt = prepare_target(ann, img_id, (img["height"], img["width"]))
if img_ids is not None:
img_ids[idx] = img_id
if isinstance(boxes, list):
boxes[idx] = tgt["boxes"]
@@ -417,8 +420,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
if val:
matches, anchors = None, None
boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
img_ids, boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
else:
img_ids = None
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
@@ -464,7 +468,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
if val:
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
image_ids[bc * batch_size:(bc + 1) * batch_size],
img_ids[bc * batch_size:(bc + 1) * batch_size],
boxes[bc * batch_size:(bc + 1) * batch_size],
labels[bc * batch_size:(bc + 1) * batch_size],
Cookie(bc))