minor optimizations & cleaning (#257)

* use isinstance, some optimizations & whitespace removal

* revert whitespace changes

* revert more whitespace

* some more cleanup

* revert fstring (not a fan of the {{}})

* fix typo

* fix typo
This commit is contained in:
Josh Smith
2021-06-02 12:57:15 -04:00
committed by GitHub
parent 74e874cc0d
commit ad756f6112
9 changed files with 31 additions and 36 deletions

View File

@@ -135,4 +135,4 @@ if __name__ == "__main__":
X_aug = X_train if epoch == 1 else augment_img(X_train)
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
accuracy = evaluate(model, X_test, Y_test, BS=BS)
model.save('examples/checkpoint'+str("%.0f" % (accuracy*1.0e6)))
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')

View File

@@ -5,8 +5,7 @@ from tinygrad.tensor import Tensor
class MaxPool2d:
def __init__(self, kernel_size, stride):
if type(kernel_size) == int:
self.kernel_size = (kernel_size, kernel_size)
if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size)
else: self.kernel_size = kernel_size
self.stride = stride if (stride is not None) else kernel_size
@@ -62,9 +61,7 @@ class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, groups = 1, bias = True):
self.in_channels, self.out_channels, self.stride, self.padding, self.groups, self.bias = in_channels, out_channels, stride, padding, groups, bias # Wow this is terrible
if type(kernel_size) == int:
self.kernel_size = (kernel_size, kernel_size)
else: self.kernel_size = kernel_size
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
assert out_channels % groups == 0 and in_channels % groups == 0

View File

@@ -64,7 +64,7 @@ def letterbox_image(img, inp_dim=608):
return canvas
def add_boxes(img, prediction):
if type(prediction) is int: # no predictions
if isinstance(prediction, int): # no predictions
return img
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
coco_labels = coco_labels.decode('utf-8').split('\n')