not forcing 3.9 for a stupid type

This commit is contained in:
George Hotz
2021-10-30 16:52:40 -07:00
parent 114f6ca3fd
commit 7472a7ebe2
2 changed files with 8 additions and 12 deletions

View File

@@ -9,16 +9,14 @@ from PIL import Image
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
def _load_labels() -> list[str]:
def _load_labels():
labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
return ast.literal_eval(labels_filename.read_text())
_LABELS = _load_labels()
def _infer(model: EfficientNet, img) -> str:
def _infer(model: EfficientNet, img):
# preprocess image
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))