mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update image to be float32
This commit is contained in:
@@ -403,7 +403,7 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
|
||||
queue_in, queue_out = Queue(), Queue()
|
||||
procs, data_out_count = [], [0] * batch_count
|
||||
shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.uint8)
|
||||
shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
|
||||
shm_y_boxes, Y_boxes = _setup_shared_mem("retinanet_y_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_y_labels, Y_labels = _setup_shared_mem("retinanet_y_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
|
||||
|
||||
@@ -190,7 +190,7 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si
|
||||
import torchvision.transforms.functional as F
|
||||
img_size = img.size[::-1]
|
||||
img = F.resize(img, size=size)
|
||||
img = np.array(img)
|
||||
img = np.array(img, dtype=np.float32)
|
||||
|
||||
if tgt is not None:
|
||||
ratios = [s / s_orig for s, s_orig in zip(size, img_size)]
|
||||
|
||||
Reference in New Issue
Block a user