add name support to fetch (#2407)

* add name support

* use fetch in gpt2

* remove requests from main lib, networkx also optional

* umm, keep that assert

* updates to fetch

* i love the walrus so much

* stop bundling mnist with tinygrad

* err, https

* download cache names

* add DOWNLOAD_CACHE_VERSION

* need env.

* ugh, wrong path

* replace get_child
This commit is contained in:
George Hotz
2023-11-23 14:16:17 -08:00
committed by GitHub
parent 397c093656
commit 095e2ced61
16 changed files with 73 additions and 79 deletions

View File

@@ -1,7 +1,8 @@
import math
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d
from extra.utils import get_child
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
@@ -142,9 +143,7 @@ class EfficientNet:
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
}
from extra.utils import fetch_as_file
from tinygrad.nn.state import torch_load
b0 = torch_load(fetch_as_file(model_urls[self.number]))
b0 = torch_load(fetch(model_urls[self.number]))
for k,v in b0.items():
if k.endswith("num_batches_tracked"): continue
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']: