Fix examples (#540)

* Fix examples

* Remove training in parameters

* Simplify a bit

* Remove extra import

* Fix linter errors

* factor out Device

* NumPy-like semantics for Tensor.__getitem__ (#506)

* Rewrote Tensor.__getitem__ to fix negative indices and add support for np.newaxis/None

* Fixed pad2d

* mypy doesn't know about mlops methods

* normal python behavior for out-of-bounds slicing

* type: ignore

* inlined idxfix

* added comment for __getitem__

* Better comments, better tests, and fixed bug in np.newaxis

* update cpu and torch to hold buffers (#542)

* update cpu and torch to hold buffers

* save lines, and probably faster

* Mypy fun (#541)

* mypy fun

* things are just faster

* running fast

* mypy is fast

* compile.sh

* no gpu hack

* refactor ops_cpu and ops_torch to not subclass

* make weak buffer work

* tensor works

* fix test failing

* cpu/torch cleanups

* no or operator on dict in python 3.8

* that was junk

* fix warnings

* comment and touchup

* dyn add of math ops

* refactor ops_cpu and ops_torch to not share code

* nn/optim.py compiles now

* Reorder imports

* call mkdir only if directory doesn't exist

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: Mitchell Goff <mitchellgoffpc@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Jacky Lee
2023-02-10 10:09:37 -08:00
committed by GitHub
parent 56a06280c5
commit f08187526f
13 changed files with 115 additions and 178 deletions

View File

@@ -1,16 +1,16 @@
#!/usr/bin/env python3
import gc
import time
from tqdm import trange
from models.efficientnet import EfficientNet
import tinygrad.nn.optim as optim
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
from tinygrad.runtime.opencl import CL
from tinygrad.ops import GlobalCounters
from tinygrad.helpers import getenv
import gc
def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
return sum(isinstance(x, Tensor) for x in gc.get_objects())
NUM = getenv("NUM", 2)
BS = getenv("BS", 8)
@@ -63,8 +63,3 @@ if __name__ == "__main__":
cl = time.monotonic()
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")

View File

@@ -2,16 +2,18 @@
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import os
import sys
import io
import ast
import time
import cv2
import numpy as np
np.set_printoptions(suppress=True)
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.utils import fetch, get_parameters
from extra.utils import fetch
from models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)
def infer(model, img):
# preprocess image
@@ -53,15 +55,12 @@ if __name__ == "__main__":
model.load_from_pretrained()
# category labels
import ast
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
# load image and preprocess
from PIL import Image
url = sys.argv[1]
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/geohot/tinygrad/master/docs/stable_diffusion_by_tinygrad.jpg"
if url == 'webcam':
import cv2
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
while 1:
@@ -86,5 +85,4 @@ if __name__ == "__main__":
st = time.time()
out, _ = infer(model, img)
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
print("did inference in %.2f s" % (time.time()-st))
#print("NOT", np.argmin(out.data), np.min(out.data), lbls[np.argmin(out.data)])
# Report wall-clock inference time with two decimals; ".2f" sets precision (a bare "2f" would only set a minimum width).
print(f"did inference in {(time.time()-st):.2f} s")

View File

@@ -3,16 +3,16 @@
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
# https://siboehm.com/articles/22/CUDA-MMM
# TODO: gelu is causing nans!
import numpy as np
import time
import numpy as np
from datasets import fetch_cifar
from tinygrad import nn
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.training import train, evaluate
from extra.utils import get_parameters
from tinygrad.ops import GlobalCounters
from tinygrad.llops.ops_gpu import CL
from extra.utils import get_parameters
num_classes = 10
@@ -53,7 +53,6 @@ class SpeedyResNet:
# TODO: this will become @tinygrad.jit
first, cl_cache, loss = True, None, None
from tinygrad.runtime.opencl import CL
def train_step_jitted(model, optimizer, X, Y, enable_jit=False):
global cl_cache, first, loss
GlobalCounters.reset()
@@ -95,7 +94,6 @@ def train_cifar():
X_test,Y_test = fetch_cifar(train=False)
Xt, Yt = fetch_batch(X_test, Y_test, BS=BS)
model = SpeedyResNet()
def make_lr(x): return Tensor([x]).realize()
optimizer = optim.SGD(get_parameters(model), lr=0.001)
#optimizer = optim.Adam(get_parameters(model), lr=3e-4)

View File

@@ -3,20 +3,17 @@ import os
import sys
import numpy as np
from tqdm import tqdm
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
import torch
from torchvision.utils import make_grid, save_image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.utils import get_parameters
import tinygrad.nn.optim as optim
from extra.utils import get_parameters
from datasets import fetch_mnist
from torchvision.utils import make_grid, save_image
import torch
GPU = getenv("GPU")
class LinearGen:
def __init__(self):
lv = 128
self.l1 = Tensor.uniform(128, 256)
self.l2 = Tensor.uniform(256, 512)
self.l3 = Tensor.uniform(512, 1024)
@@ -31,7 +28,6 @@ class LinearGen:
class LinearDisc:
def __init__(self):
in_sh = 784
self.l1 = Tensor.uniform(784, 1024)
self.l2 = Tensor.uniform(1024, 512)
self.l3 = Tensor.uniform(512, 256)
@@ -124,7 +120,7 @@ if __name__ == "__main__":
loss_g = 0.0
loss_d = 0.0
print(f"Epoch {epoch} of {epochs}")
for i in tqdm(range(n_steps)):
for _ in tqdm(range(n_steps)):
image = generator_batch()
for step in range(k): # Try with k = 5 or 7.
noise = Tensor(np.random.randn(batch_size,128))
@@ -143,5 +139,4 @@ if __name__ == "__main__":
epoch_loss_g = loss_g / n_steps
epoch_loss_d = loss_d / n_steps
print(f"EPOCH: Generator loss: {epoch_loss_g}, Discriminator loss: {epoch_loss_d}")
else:
print("Training Completed!")
print("Training Completed!")

View File

@@ -1,19 +1,14 @@
#!/usr/bin/env python
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2D
from extra.utils import get_parameters
from datasets import fetch_mnist
from extra.training import train, evaluate, sparse_categorical_crossentropy
import tinygrad.nn.optim as optim
from tinygrad.nn import BatchNorm2D, optim
from tinygrad.helpers import getenv
from datasets import fetch_mnist
from extra.augment import augment_img
from extra.utils import get_parameters
from extra.training import train, evaluate, sparse_categorical_crossentropy
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")

View File

@@ -12,9 +12,9 @@ from collections import namedtuple
import numpy as np
from tqdm import tqdm
from extra.utils import fake_torch_load_zipped, get_child
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm
from extra.utils import fake_torch_load_zipped, get_child
# TODO: refactor AttnBlock, CrossAttention, CLIPAttention to share code
@@ -394,7 +394,7 @@ class CLIPAttention:
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.softmax()
attn_output = attn_weights @ value_states
@@ -429,7 +429,7 @@ class CLIPEncoderLayer:
class CLIPEncoder:
def __init__(self):
self.layers = [CLIPEncoderLayer() for i in range(12)]
def __call__(self, hidden_states, causal_attention_mask):
for l in self.layers:
hidden_states = l(hidden_states, causal_attention_mask)
@@ -556,8 +556,7 @@ class ClipTokenizer:
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
@@ -620,7 +619,7 @@ if __name__ == "__main__":
w = get_child(model, k)
except (AttributeError, KeyError, IndexError):
#traceback.print_exc()
w = None
w = None
#print(f"{str(v.shape):30s}", w.shape if w is not None else w, k)
if w is not None:
assert w.shape == v.shape

View File

@@ -1,31 +1,27 @@
import os
import traceback
import time
from multiprocessing import Process, Queue
import numpy as np
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from extra.utils import get_parameters
from tqdm import trange
from tinygrad.nn import BatchNorm2D
import tinygrad.nn.optim as optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from datasets import fetch_cifar
from datasets.imagenet import fetch_batch
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
class TinyConvNet:
def __init__(self, classes=10):
conv = 3
inter_chan, out_chan = 8, 16 # for speed
self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
#self.bn1 = BatchNorm2D(inter_chan)
self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
#self.bn2 = BatchNorm2D(out_chan)
self.l1 = Tensor.uniform(out_chan*6*6, classes)
def forward(self, x):
x = x.conv2d(self.c1).relu().max_pool2d()
#x = self.bn1(x)
x = x.conv2d(self.c2).relu().max_pool2d()
#x = self.bn2(x)
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1)
@@ -47,12 +43,10 @@ if __name__ == "__main__":
print("parameter count", len(parameters))
optimizer = optim.Adam(parameters, lr=0.001)
BS, steps = getenv("BS", 64 if TINY else 16)), getenv("STEPS", 2048))
print("training with batch size %d for %d steps" % (BS, steps))
BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
print(f"training with batch size {BS} for {steps} steps")
if IMAGENET:
from datasets.imagenet import fetch_batch
from multiprocessing import Process, Queue
def loader(q):
while 1:
try:
@@ -94,8 +88,6 @@ if __name__ == "__main__":
optimizer.step()
opt_time = (time.time()-st)*1000.0
#print(out.cpu().data)
st = time.time()
loss = loss.cpu().data
cat = np.argmax(out.cpu().data, axis=1)

View File

@@ -1,14 +1,12 @@
#!/usr/bin/env python3
import numpy as np
import random
from PIL import Image
from tinygrad.tensor import Device
from tinygrad.nn.optim import Adam
from tinygrad.helpers import getenv
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.resnet import ResNet
from tinygrad.nn.optim import Adam
from tinygrad.helpers import getenv
from datasets import fetch_mnist
@@ -39,7 +37,7 @@ if __name__ == "__main__":
lambda x: x / 255.0,
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
])
for i in range(10):
for _ in range(10):
optim = Adam(get_parameters(model), lr=lr)
train(model, X_train, Y_train, optim, 50, BS=32, transform=transform)
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True, transform=transform)

View File

@@ -2,11 +2,10 @@
import numpy as np
import random
from tinygrad.tensor import Device
from tinygrad.nn.optim import Adam
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.transformer import Transformer
from tinygrad.nn.optim import Adam
# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb
def make_dataset():
@@ -21,13 +20,10 @@ def make_dataset():
ds_Y = np.copy(ds[:, 1:])
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
from tinygrad.nn.optim import Adam
if __name__ == "__main__":
model = Transformer(10, 6, 2, 128, 4, 32)
X_train, Y_train, X_test, Y_test = make_dataset()
lr = 0.003
for i in range(10):

View File

@@ -1,13 +1,13 @@
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD
import examples.yolo.waifu2x
from examples.yolo.kinne import KinneDir
import sys
import os
import random
import json
import numpy
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD
from examples.yolo.kinne import KinneDir
from examples.yolo.waifu2x import image_load, image_save, Vgg7
# amount of context erased by model
CONTEXT = 7
@@ -22,9 +22,8 @@ def get_sample_count(samples_dir):
return 0
def set_sample_count(samples_dir, sc):
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "w")
samples_dir_count_file.write(str(sc) + "\n")
samples_dir_count_file.close()
with open(samples_dir + "/sample_count.txt", "w") as file:
file.write(str(sc) + "\n")
if len(sys.argv) < 2:
print("python3 -m examples.vgg7 import MODELJSON MODELDIR")
@@ -58,7 +57,7 @@ if len(sys.argv) < 2:
sys.exit(1)
cmd = sys.argv[1]
vgg7 = extra.waifu2x.Vgg7()
vgg7 = Vgg7()
def nansbane(p):
if numpy.isnan(numpy.min(p.data)):
@@ -81,7 +80,8 @@ if cmd == "import":
vgg7.load_waifu2x_json(json.load(open(src, "rb")))
os.mkdir(model)
if not os.path.isdir(model):
os.mkdir(model)
load_and_save(model, True)
elif cmd == "execute":
model = sys.argv[2]
@@ -90,7 +90,7 @@ elif cmd == "execute":
load_and_save(model, False)
extra.waifu2x.image_save(out_file, vgg7.forward(Tensor(extra.waifu2x.image_load(in_file))).data)
image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).data)
elif cmd == "execute_full":
model = sys.argv[2]
in_file = sys.argv[3]
@@ -98,11 +98,12 @@ elif cmd == "execute_full":
load_and_save(model, False)
extra.waifu2x.image_save(out_file, vgg7.forward_tiled(extra.waifu2x.image_load(in_file), 156))
image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
elif cmd == "new":
model = sys.argv[2]
os.mkdir(model)
if not os.path.isdir(model):
os.mkdir(model)
load_and_save(model, True)
elif cmd == "train":
model = sys.argv[2]
@@ -127,7 +128,6 @@ elif cmd == "train":
except:
# it's fine
print("sample probs could not be loaded - initializing")
pass
if sample_probs is None:
# This stupidly high amount is used to force an initial pass over all samples
@@ -151,8 +151,8 @@ elif cmd == "train":
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
sample_idx = random.randint(0, samples_count - 1)
x_img = extra.waifu2x.image_load(samples_base + "/" + str(sample_idx) + "a.png")
y_img = extra.waifu2x.image_load(samples_base + "/" + str(sample_idx) + "b.png")
x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
sample_x = Tensor(x_img, requires_grad = False)
sample_y = Tensor(y_img, requires_grad = False)
@@ -209,8 +209,8 @@ elif cmd == "samplify":
# This bit is interesting because it actually does some work.
# Not much, but some work.
a_img = extra.waifu2x.image_load(a_img)
b_img = extra.waifu2x.image_load(b_img)
a_img = image_load(a_img)
b_img = image_load(b_img)
# as with the main library body,
# Y X order is used here
@@ -238,14 +238,13 @@ elif cmd == "samplify":
patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
extra.waifu2x.image_save(samples_base + "/" + str(samples_count) + "a.png", patch_x)
extra.waifu2x.image_save(samples_base + "/" + str(samples_count) + "b.png", patch_y)
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
image_save(f"{samples_base}/{str(samples_count)}b.png", patch_y)
samples_count += 1
samples_added += 1
print("Added " + str(samples_added) + " samples")
print(f"Added {str(samples_added)} samples")
set_sample_count(samples_base, samples_count)
else:
print("unknown command")

View File

@@ -1,5 +1,11 @@
import ast
import io
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from models.vit import ViT
from extra.utils import fetch
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
@@ -9,11 +15,6 @@ with tf.io.gfile.GFile(fn, "rb") as f:
g.write(dat)
"""
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from models.vit import ViT
Tensor.training = False
if getenv("LARGE", 0) == 1:
m = ViT(embed_dim=768, num_heads=12)
@@ -23,9 +24,6 @@ else:
m.load_from_pretrained()
# category labels
import ast
import io
from extra.utils import fetch
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
@@ -33,7 +31,6 @@ lbls = ast.literal_eval(lbls.decode('utf-8'))
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
# junk
from PIL import Image
img = Image.open(io.BytesIO(fetch(url)))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
@@ -50,5 +47,3 @@ out = m.forward(Tensor(img))
outnp = out.cpu().data.ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])

View File

@@ -35,12 +35,8 @@ class KinneDir:
It is important that if you wish to save in the current directory,
you use ".", not the empty string.
"""
if save:
try:
os.mkdir(base)
except:
# Silence the exception - the directory may (and if reading, does) already exist.
pass
if save and not os.path.isdir(base):
os.mkdir(base)
self.base = base + "/snoop_bin_"
self.next_part_index = 0
self.save = save

View File

@@ -1,21 +1,19 @@
# https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg
# running
# running
import sys
import io
import time
import cv2
import numpy as np
np.set_printoptions(suppress=True)
from PIL import Image
from tinygrad.tensor import Tensor
from extra.utils import fetch, get_parameters
from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d
from tinygrad.nn import BatchNorm2D, Conv2d
from tinygrad.helpers import getenv
from extra.utils import fetch, get_parameters
from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d
np.set_printoptions(suppress=True)
GPU = getenv("GPU")
import cv2
from PIL import Image
def show_labels(prediction, confidence = 0.5, num_classes = 80):
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
coco_labels = coco_labels.decode('utf-8').split('\n')
@@ -29,7 +27,7 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
def numpy_max(input, dim):
# Input -> tensor (10x8)
return np.amax(input, axis=dim), np.argmax(input, axis=dim)
# Iterate over batches
for i in range(prediction.shape[0]):
img_pred = prediction[i]
@@ -40,7 +38,7 @@ def show_labels(prediction, confidence = 0.5, num_classes = 80):
image_pred = np.concatenate(seq, axis=1)
non_zero_ind = np.nonzero(image_pred[:,4])[0]
assert(all(image_pred[non_zero_ind,0] > 0))
assert all(image_pred[non_zero_ind,0] > 0)
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
try:
@@ -59,10 +57,9 @@ def letterbox_image(img, inp_dim=608):
new_w = int(img_w * min(w/img_w, h/img_h))
new_h = int(img_h * min(w/img_w, h/img_h))
resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC)
canvas = np.full((inp_dim[1], inp_dim[0], 3), 128)
canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w, :] = resized_image
return canvas
def add_boxes(img, prediction):
@@ -89,7 +86,6 @@ def add_boxes(img, prediction):
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
return img
def bbox_iou(box1, box2):
@@ -127,17 +123,16 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
conf_mask = (prediction[:,:,4] > confidence)
conf_mask = np.expand_dims(conf_mask, 2)
prediction = prediction * conf_mask
# Non max suppression
box_corner = prediction
box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
prediction[:,:,:4] = box_corner[:,:,:4]
batch_size = prediction.shape[0]
write = False
# Process img
@@ -146,7 +141,7 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
def numpy_max(input, dim):
# Input -> tensor (10x8)
return np.amax(input, axis=dim), np.argmax(input, axis=dim)
max_conf, max_conf_score = numpy_max(img_pred[:,5:5 + num_classes], 1)
max_conf_score = np.expand_dims(max_conf_score, axis=1)
max_conf = np.expand_dims(max_conf, axis=1)
@@ -154,7 +149,7 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
image_pred = np.concatenate(seq, axis=1)
non_zero_ind = np.nonzero(image_pred[:,4])[0]
assert(all(image_pred[non_zero_ind,0] > 0))
assert all(image_pred[non_zero_ind,0] > 0)
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
try:
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
@@ -165,7 +160,7 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
if image_pred_.shape[0] == 0:
print("No detections found!")
return 0
def unique(tensor):
tensor_np = tensor
unique_np = np.unique(tensor_np)
@@ -179,35 +174,34 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0
class_mask_ind = np.squeeze(np.nonzero(cls_mask[:,-2]))
# class_mask_ind = np.nonzero()
image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7))
# sort the detections such that the entry with the maximum objectness
# confidence is at the top
conf_sort_index = np.argsort(image_pred_class[:,4])
image_pred_class = image_pred_class[conf_sort_index]
idx = image_pred_class.shape[0] #Number of detections
for i in range(idx):
#Get the IOUs of all boxes that come after the one we are looking at
#in the loop
# Get the IOUs of all boxes that come after the one we are looking at in the loop
try:
ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:])
except ValueError:
break
except IndexError:
break
# Zero out all the detections that have IoU > threshold
iou_mask = np.expand_dims((ious < nms_conf), axis=1)
image_pred_class[i+1:] *= iou_mask
# Remove the non-zero entries
non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:,4]))
image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7))
image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7))
batch_ind = np.array([[0]])
seq = (batch_ind, image_pred_class)
if not write:
output = np.concatenate(seq, 1)
write = True
@@ -228,10 +222,9 @@ def resize(img, inp_dim=(608, 608)):
new_w = int(img_w * min(w/img_w, h/img_h))
new_h = int(img_h * min(w/img_w, h/img_h))
resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC)
canvas = np.full((inp_dim[1], inp_dim[0], 3), 128)
canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w, :] = resized_image
return canvas
def infer(model, img):
@@ -275,12 +268,12 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
grid_size = inp_dim // stride
bbox_attrs = 5 + num_classes
num_anchors = len(anchors)
prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size))
# Original PyTorch: transpose(1, 2) -> For some reason numpy.transpose order has to be reversed?
prediction = prediction.transpose(order=(0, 2, 1))
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
# st = time.time()
prediction_cpu = prediction.cpu().data
# print('put on CPU in %.2f s' % (time.time() - st))
@@ -290,11 +283,11 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
# TODO: Fix this
def dsigmoid(data):
return 1/(1+np.exp(-data))
prediction_cpu[:,:,0] = dsigmoid(prediction_cpu[:,:,0])
prediction_cpu[:,:,1] = dsigmoid(prediction_cpu[:,:,1])
prediction_cpu[:,:,4] = dsigmoid(prediction_cpu[:,:,4])
# Add the center offsets
grid = np.arange(grid_size)
a, b = np.meshgrid(grid, grid)
@@ -315,7 +308,6 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
prediction_cpu[:,:,2:4] = np.exp(prediction_cpu[:,:,2:4])*anchors
prediction_cpu[:,:,5: 5 + num_classes] = dsigmoid((prediction_cpu[:,:, 5 : 5 + num_classes]))
prediction_cpu[:,:,:4] *= stride
prediction.gpu_()
return Tensor(prediction_cpu)
@@ -352,19 +344,19 @@ class Darknet:
pad = (int(x["size"]) - 1) // 2
else:
pad = 0
conv = Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias = bias)
module.append(conv)
# BatchNorm2d
if batch_normalize:
bn = BatchNorm2D(filters, eps=1e-05, training=True, track_running_stats=True)
bn = BatchNorm2D(filters, eps=1e-05, track_running_stats=True)
module.append(bn)
# LeakyReLU activation
if activation == "leaky":
module.append(LeakyReLU(0.1))
# TODO: Add tiny model
elif module_type == "maxpool":
size = int(x["size"])
@@ -375,7 +367,7 @@ class Darknet:
elif module_type == "upsample":
upsample = Upsample(scale_factor = 2, mode = "nearest")
module.append(upsample)
elif module_type == "route":
x["layers"] = x["layers"].split(",")
# Start of route
@@ -393,11 +385,11 @@ class Darknet:
filters = output_filters[index + start] + output_filters[index + end]
else:
filters = output_filters[index + start]
# Shortcut corresponds to skip connection
elif module_type == "shortcut":
module.append(EmptyLayer())
elif module_type == "yolo":
mask = x["mask"].split(",")
mask = [int(x) for x in mask]
@@ -409,15 +401,15 @@ class Darknet:
detection = DetectionLayer(anchors)
module.append(detection)
# Append to module_list
module_list.append(module)
if filters is not None:
prev_filters = filters
output_filters.append(filters)
return (net_info, module_list)
def dump_weights(self):
for i in range(len(self.module_list)):
module_type = self.blocks[i + 1]["type"]
@@ -432,7 +424,7 @@ class Darknet:
print(conv.bias.cpu().data[0][0:5])
else:
print("None biases for layer", i)
def load_weights(self, url):
weights = fetch(url)
# First 5 values (major, minor, subversion, Images seen)
@@ -456,10 +448,10 @@ class Darknet:
batch_normalize = int(self.blocks[i + 1]["batch_normalize"])
except: # no batchnorm, load conv weights + biases
batch_normalize = 0
conv = model[0]
if (batch_normalize):
if batch_normalize:
bn = model[1]
# Get the number of weights of batchnorm
@@ -502,7 +494,7 @@ class Darknet:
# Copy
conv.bias = conv_biases
# Load weights for conv layers
num_weights = numel(conv.weight)
@@ -512,9 +504,6 @@ class Darknet:
conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape))
conv.weight = conv_weights
def forward(self, x):
modules = self.blocks[1:]
outputs = {} # Cached outputs for route layer
@@ -522,36 +511,28 @@ class Darknet:
for i, module in enumerate(modules):
module_type = (module["type"])
st = time.time()
if module_type == "convolutional" or module_type == "upsample":
for index, layer in enumerate(self.module_list[i]):
for layer in self.module_list[i]:
x = layer(x)
elif module_type == "route":
layers = module["layers"]
layers = [int(a) for a in layers]
if (layers[0]) > 0:
layers[0] = layers[0] - i
if len(layers) == 1:
x = outputs[i + (layers[0])]
else:
if (layers[1]) > 0: layers[1] = layers[1] - i
map1 = outputs[i + layers[0]]
map2 = outputs[i + layers[1]]
x = Tensor(np.concatenate((map1.cpu().data, map2.cpu().data), 1))
elif module_type == "shortcut":
from_ = int(module["from"])
x = outputs[i - 1] + outputs[i + from_]
elif module_type == "yolo":
anchors = self.module_list[i][0].anchors
inp_dim = int(self.net_info["height"])
# inp_dim = 416
num_classes = int(module["classes"])
# Transform
x = predict_transform(x, inp_dim, anchors, num_classes)
@@ -560,10 +541,10 @@ class Darknet:
write = 1
else:
detections = Tensor(np.concatenate((detections.cpu().data, x.cpu().data), 1))
# print(module_type, 'layer took %.2f s' % (time.time() - st))
outputs[i] = x
return detections # Return detections
if __name__ == "__main__":
@@ -614,12 +595,12 @@ if __name__ == "__main__":
img = cv2.imdecode(np.fromstring(img_stream.read(), np.uint8), 1)
else:
img = cv2.imread(url)
# Predict
st = time.time()
print('running inference…')
prediction = infer(model, img)
print('did inference in %.2f s' % (time.time() - st))
# Report wall-clock inference time with two decimals; ".2f" sets precision (a bare "2f" would only set a minimum width).
print(f'did inference in {(time.time() - st):.2f} s')
labels = show_labels(prediction)
prediction = process_results(prediction)