mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Add asserts for non-zero indices (#264)
This commit is contained in:
@@ -39,7 +39,9 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
|
||||
seq = (img_pred[:,:5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0] # TODO: Check if this is right
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0]
|
||||
assert(all(image_pred[non_zero_ind,0] > 0))
|
||||
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
try:
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
@@ -151,7 +153,8 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
|
||||
seq = (img_pred[:,:5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0] # TODO: Check if this is right
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0]
|
||||
assert(all(image_pred[non_zero_ind,0] > 0))
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
try:
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
@@ -581,8 +584,10 @@ if __name__ == "__main__":
|
||||
params = get_parameters(model)
|
||||
[x.gpu_() for x in params]
|
||||
|
||||
url = sys.argv[1]
|
||||
# url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||
|
||||
img = None
|
||||
# We use cv2 because for some reason, cv2 imread produces better results?
|
||||
@@ -616,8 +621,9 @@ if __name__ == "__main__":
|
||||
prediction = infer(model, img)
|
||||
print('did inference in %.2f s' % (time.time() - st))
|
||||
|
||||
labels = show_labels(prediction)
|
||||
prediction = process_results(prediction)
|
||||
# print(prediction)
|
||||
boxes = add_boxes(imresize(img, 608, 608), prediction)
|
||||
# Save img
|
||||
cv2.imwrite('boxes.jpg', boxes)
|
||||
cv2.imwrite('boxes.jpg', boxes)
|
||||
|
||||
Reference in New Issue
Block a user