mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
[ready] Replacing os with pathlib (#1708)
* replace os.path with pathlib * safe convert dirnames to pathlib * replace all os.path.join * fix cuda error * change main chunk * Reviewer fixes * fix vgg * Fixed everything * Final fixes * ensure consistency * Change all parent.parent... to parents
This commit is contained in:
@@ -10,9 +10,9 @@ class BertForQuestionAnswering:
|
||||
self.qa_outputs = Linear(hidden_size, 2)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parent.parent / "weights/bert_for_qa.pt"
|
||||
fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt"
|
||||
download_file("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
|
||||
fn_vocab = Path(__file__).parent.parent / "weights/bert_vocab.txt"
|
||||
fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt"
|
||||
download_file("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
||||
|
||||
import torch
|
||||
|
||||
@@ -60,7 +60,7 @@ class RNNT:
|
||||
return out.realize()
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parent.parent / "weights/rnnt.pt"
|
||||
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
||||
download_file("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)
|
||||
|
||||
import torch
|
||||
|
||||
@@ -46,7 +46,7 @@ class UNet3D:
|
||||
return x
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parent.parent / "weights" / "unet-3d.ckpt"
|
||||
fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
|
||||
download_file("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
|
||||
state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
|
||||
for k, v in state_dict.items():
|
||||
|
||||
Reference in New Issue
Block a user