Add asserts for non-zero indices (#264)

This commit is contained in:
Jacky Lee
2021-06-13 21:14:46 -07:00
committed by GitHub
parent 508ced114c
commit 611d81dcb4

View File

@@ -39,7 +39,9 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
seq = (img_pred[:,:5], max_conf, max_conf_score)
image_pred = np.concatenate(seq, axis=1)
non_zero_ind = np.nonzero(image_pred[:,4])[0] # TODO: Check if this is right
non_zero_ind = np.nonzero(image_pred[:,4])[0]
assert(all(image_pred[non_zero_ind,0] > 0))
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
try:
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
@@ -151,7 +153,8 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
seq = (img_pred[:,:5], max_conf, max_conf_score)
image_pred = np.concatenate(seq, axis=1)
non_zero_ind = np.nonzero(image_pred[:,4])[0] # TODO: Check if this is right
non_zero_ind = np.nonzero(image_pred[:,4])[0]
assert(all(image_pred[non_zero_ind,0] > 0))
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
try:
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
@@ -581,8 +584,10 @@ if __name__ == "__main__":
params = get_parameters(model)
[x.gpu_() for x in params]
url = sys.argv[1]
# url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
if len(sys.argv) > 1:
url = sys.argv[1]
else:
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
img = None
# We use cv2 because for some reason, cv2 imread produces better results?
@@ -616,8 +621,9 @@ if __name__ == "__main__":
prediction = infer(model, img)
print('did inference in %.2f s' % (time.time() - st))
labels = show_labels(prediction)
prediction = process_results(prediction)
# print(prediction)
boxes = add_boxes(imresize(img, 608, 608), prediction)
# Save img
cv2.imwrite('boxes.jpg', boxes)
cv2.imwrite('boxes.jpg', boxes)