mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix img_ids repeating its values
This commit is contained in:
@@ -393,6 +393,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
ann = dataset.loadAnns(dataset.getAnnIds(img_id:=img["id"]))
|
||||
tgt = prepare_target(ann, img_id, (img["height"], img["width"]))
|
||||
|
||||
if img_ids is not None:
|
||||
img_ids[idx] = img_id
|
||||
|
||||
if isinstance(boxes, list):
|
||||
boxes[idx] = tgt["boxes"]
|
||||
|
||||
@@ -417,8 +420,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
|
||||
if val:
|
||||
matches, anchors = None, None
|
||||
boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
|
||||
img_ids, boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
|
||||
else:
|
||||
img_ids = None
|
||||
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
@@ -464,7 +468,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
|
||||
if val:
|
||||
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
|
||||
image_ids[bc * batch_size:(bc + 1) * batch_size],
|
||||
img_ids[bc * batch_size:(bc + 1) * batch_size],
|
||||
boxes[bc * batch_size:(bc + 1) * batch_size],
|
||||
labels[bc * batch_size:(bc + 1) * batch_size],
|
||||
Cookie(bc))
|
||||
|
||||
Reference in New Issue
Block a user