mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
Update yolov3.py (#2680)
The current yolov3 example is broken with the current implementation of of fetch in the helpers. I was tempted to fix the helpers instead but that could have just as well broken other examples.
This commit is contained in:
@@ -38,7 +38,7 @@ def show_labels(prediction, confidence=0.5, num_classes=80):
|
||||
def add_boxes(img, prediction):
|
||||
if isinstance(prediction, int): # no predictions
|
||||
return img
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes()
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
height, width = img.shape[0:2]
|
||||
scale_factor = 608 / width
|
||||
@@ -281,7 +281,7 @@ class Darknet:
|
||||
print("None biases for layer", i)
|
||||
|
||||
def load_weights(self, url):
|
||||
weights = np.frombuffer(fetch(url), dtype=np.float32)[5:]
|
||||
weights = np.frombuffer(fetch(url).read_bytes(), dtype=np.float32)[5:]
|
||||
ptr = 0
|
||||
for i in range(len(self.module_list)):
|
||||
module_type = self.blocks[i + 1]["type"]
|
||||
@@ -369,7 +369,7 @@ class Darknet:
|
||||
return detections
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg'))
|
||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg').read_bytes())
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
|
||||
if len(sys.argv) > 1:
|
||||
@@ -392,7 +392,7 @@ if __name__ == "__main__":
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
elif url.startswith('http'):
|
||||
img_stream = io.BytesIO(fetch(url))
|
||||
img_stream = io.BytesIO(fetch(url).read_bytes())
|
||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||
else:
|
||||
img = cv2.imread(url)
|
||||
|
||||
Reference in New Issue
Block a user