mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
minor optimizations & cleaning (#257)
* use isinstance, some optimizations & whitespace removal
* revert whitespace changes
* revert more whitespace
* some more cleanup
* revert fstring (not a fan of the {{}})
* fix typo
* fix typo
This commit is contained in:
@@ -135,4 +135,4 @@ if __name__ == "__main__":
|
||||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||
model.save('examples/checkpoint'+str("%.0f" % (accuracy*1.0e6)))
|
||||
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
|
||||
|
||||
@@ -5,8 +5,7 @@ from tinygrad.tensor import Tensor
|
||||
|
||||
class MaxPool2d:
|
||||
def __init__(self, kernel_size, stride):
|
||||
if type(kernel_size) == int:
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size)
|
||||
else: self.kernel_size = kernel_size
|
||||
self.stride = stride if (stride is not None) else kernel_size
|
||||
|
||||
@@ -62,9 +61,7 @@ class Conv2d:
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, groups = 1, bias = True):
|
||||
self.in_channels, self.out_channels, self.stride, self.padding, self.groups, self.bias = in_channels, out_channels, stride, padding, groups, bias # Wow this is terrible
|
||||
|
||||
if type(kernel_size) == int:
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
else: self.kernel_size = kernel_size
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
|
||||
assert out_channels % groups == 0 and in_channels % groups == 0
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ def letterbox_image(img, inp_dim=608):
|
||||
return canvas
|
||||
|
||||
def add_boxes(img, prediction):
|
||||
if type(prediction) is int: # no predictions
|
||||
if isinstance(prediction, int): # no predictions
|
||||
return img
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
|
||||
Reference in New Issue
Block a user