not forcing 3.9 for a stupid type

This commit is contained in:
George Hotz
2021-10-30 16:52:40 -07:00
parent 114f6ca3fd
commit 7472a7ebe2
2 changed files with 8 additions and 12 deletions

View File

@@ -22,10 +22,10 @@ jobs:
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -46,10 +46,10 @@ jobs:
run: sudo apt-get update
- name: Install OpenCL
run: sudo apt-get install pocl-opencl-icd
- name: Set up Python 3.9
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[testing]'
- name: Run Pytest
@@ -66,14 +66,12 @@ jobs:
run: sudo apt-get update
- name: Install OpenCL
run: sudo apt-get install pocl-opencl-icd
- name: Set up Python 3.9
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[gpu,testing]'
- name: Run Pytest
run: GPU=1 python -m pytest -s -v

View File

@@ -9,16 +9,14 @@ from PIL import Image
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
def _load_labels() -> list[str]:
def _load_labels():
labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
return ast.literal_eval(labels_filename.read_text())
_LABELS = _load_labels()
def _infer(model: EfficientNet, img) -> str:
def _infer(model: EfficientNet, img):
# preprocess image
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))