mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add name support to fetch (#2407)
* add name support * use fetch in gpt2 * remove requests from main lib, networkx also optional * umm, keep that assert * updates to fetch * i love the walrus so much * stop bundling mnist with tinygrad * err, https * download cache names * add DOWNLOAD_CACHE_VERSION * need env. * ugh, wrong path * replace get_child
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import math
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import BatchNorm2d
|
||||
from extra.utils import get_child
|
||||
from tinygrad.helpers import get_child, fetch
|
||||
from tinygrad.nn.state import torch_load
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
|
||||
@@ -142,9 +143,7 @@ class EfficientNet:
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
||||
}
|
||||
|
||||
from extra.utils import fetch_as_file
|
||||
from tinygrad.nn.state import torch_load
|
||||
b0 = torch_load(fetch_as_file(model_urls[self.number]))
|
||||
b0 = torch_load(fetch(model_urls[self.number]))
|
||||
for k,v in b0.items():
|
||||
if k.endswith("num_batches_tracked"): continue
|
||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
||||
|
||||
Reference in New Issue
Block a user