Tensor.from_url API [pr] (#7210)

* Tensor.fetch API [pr]

* update docs

* from_url
This commit is contained in:
George Hotz
2024-10-22 13:54:17 +07:00
committed by GitHub
parent be64ac417e
commit 4438d6a467
3 changed files with 19 additions and 2 deletions

View File

@@ -9,7 +9,11 @@
::: tinygrad.Tensor.full_like
::: tinygrad.Tensor.zeros_like
::: tinygrad.Tensor.ones_like
## Creation (external)
::: tinygrad.Tensor.from_blob
::: tinygrad.Tensor.from_url
## Creation (random)

View File

@@ -4,7 +4,7 @@ from tinygrad.nn.state import tar_extract
# Fetch the four MNIST dataset files (or the Fashion-MNIST ones when `fashion` is truthy) and
# return them as a tuple: (train images, train labels, test images, test labels), each moved to `device`.
def mnist(device=None, fashion=False):
# `fashion` selects which host the files are downloaded from; the file names are the same for both.
base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
# NOTE(review): this is a diff view -- the next two lines look like the removed / added pair of the hunk
# (old: wrap `fetch` in `Tensor` by hand; new: call `Tensor.from_url`). Only the second belongs to the post-commit file.
def _mnist(file): return Tensor(fetch(base_url+file, gunzip=True))
def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
# Images drop their first 0x10 elements and are reshaped to (-1,1,28,28); labels drop their first 8 elements.
# Presumably the skipped prefix is the IDX file header -- TODO confirm against the file format.
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.device import Device, Buffer, BufferOptions
@@ -422,6 +422,19 @@ class Tensor:
del r.lazydata.srcs # fake realize
return r
@staticmethod
def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
"""
Create a Tensor from a URL.

This is the preferred way to access Internet resources.
It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
This also will soon become lazy (when possible) and not print progress without DEBUG.

The `gunzip` flag will gzip extract the resource and return an extracted Tensor.
Any additional keyword arguments are passed through to the `Tensor` constructor.
"""
return Tensor(fetch(url, gunzip=gunzip), **kwargs)
_seed: int = int(time.time())
_device_seeds: Dict[str, Tensor] = {}
_device_rng_counters: Dict[str, Tensor] = {}