mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.from_url API [pr] (#7210)
* Tensor.fetch API [pr] * update docs * from_url
This commit is contained in:
@@ -9,7 +9,11 @@
|
||||
::: tinygrad.Tensor.full_like
|
||||
::: tinygrad.Tensor.zeros_like
|
||||
::: tinygrad.Tensor.ones_like
|
||||
|
||||
## Creation (external)
|
||||
|
||||
::: tinygrad.Tensor.from_blob
|
||||
::: tinygrad.Tensor.from_url
|
||||
|
||||
## Creation (random)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.nn.state import tar_extract
|
||||
|
||||
def mnist(device=None, fashion=False):
|
||||
base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
|
||||
def _mnist(file): return Tensor(fetch(base_url+file, gunzip=True))
|
||||
def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
|
||||
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
|
||||
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import MetaOps, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
@@ -422,6 +422,19 @@ class Tensor:
|
||||
del r.lazydata.srcs # fake realize
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
|
||||
"""
|
||||
Create a Tensor from a URL.
|
||||
|
||||
This is the preferred way to access Internet resources.
|
||||
It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
|
||||
This also will soon become lazy (when possible) and not print progress without DEBUG.
|
||||
|
||||
THe `gunzip` flag will gzip extract the resource and return an extracted Tensor.
|
||||
"""
|
||||
return Tensor(fetch(url, gunzip=gunzip), **kwargs)
|
||||
|
||||
_seed: int = int(time.time())
|
||||
_device_seeds: Dict[str, Tensor] = {}
|
||||
_device_rng_counters: Dict[str, Tensor] = {}
|
||||
|
||||
Reference in New Issue
Block a user