mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
some docstrings (#4201)
* feat: create and data access docstrings * fix: linter --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -91,8 +91,9 @@ plugins:
|
||||
show_source: true
|
||||
signature_crossrefs: true
|
||||
summary: true
|
||||
- markdown-exec
|
||||
#- gen-files:
|
||||
# scripts:
|
||||
# - docs/gen_ref_pages.py
|
||||
#- literate-nav:
|
||||
# nav_file: SUMMARY.md
|
||||
# nav_file: SUMMARY.md
|
||||
|
||||
1
setup.py
1
setup.py
@@ -57,6 +57,7 @@ setup(name='tinygrad',
|
||||
"mkdocs-material",
|
||||
"mkdocstrings[python]",
|
||||
"markdown-callouts",
|
||||
"markdown-exec[ansi]"
|
||||
],
|
||||
'testing_tf': [
|
||||
"tensorflow==2.15.1",
|
||||
|
||||
@@ -71,7 +71,13 @@ def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for
|
||||
def broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
|
||||
|
||||
class Tensor:
|
||||
"""A `Tensor` is a multi-dimensional matrix containing elements of a single data type."""
|
||||
"""
|
||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||
|
||||
```python exec="true" session="tensor"
|
||||
from tinygrad import Tensor
|
||||
```
|
||||
"""
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
__deletable__ = ('_ctx',)
|
||||
training: ClassVar[bool] = False
|
||||
@@ -193,14 +199,40 @@ class Tensor:
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
return self._data().cast(self.dtype.fmt, self.shape)
|
||||
def item(self) -> ConstType:
|
||||
"""Returns the value of this tensor as a standard Python number."""
|
||||
"""
|
||||
Returns the value of this tensor as a standard Python number.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor(42)
|
||||
print(t.item())
|
||||
```
|
||||
"""
|
||||
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return self._data().cast(self.dtype.fmt)[0]
|
||||
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
|
||||
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
|
||||
def tolist(self) -> Union[Sequence[ConstType], ConstType]: return self.data().tolist()
|
||||
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
|
||||
"""
|
||||
Returns the value of this tensor as a nested list.
|
||||
|
||||
Currently this only works for flattened tensors.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.tolist())
|
||||
```
|
||||
"""
|
||||
return self.data().tolist()
|
||||
def numpy(self) -> np.ndarray:
|
||||
"""
|
||||
Returns the value of this tensor as a numpy array.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3, 4])
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
if self.dtype == dtypes.bfloat16: return self.float().numpy()
|
||||
assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
@@ -241,15 +273,48 @@ class Tensor:
|
||||
return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
|
||||
def empty(*shape, **kwargs):
|
||||
"""
|
||||
Creates an empty tensor with the given shape.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.empty(2, 3)
|
||||
print(t.shape)
|
||||
```
|
||||
"""
|
||||
return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
|
||||
|
||||
_seed: int = int(time.time())
|
||||
_rng_counter: Optional[Tensor] = None
|
||||
@staticmethod
|
||||
def manual_seed(seed=0): Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
def manual_seed(seed=0):
|
||||
"""
|
||||
Sets the seed for random operations.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
print(Tensor._seed)
|
||||
```
|
||||
"""
|
||||
Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
|
||||
@staticmethod
|
||||
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values between the interval `[0, 1)`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.rand(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
if not THREEFRY.value:
|
||||
# for bfloat16, numpy rand passes buffer in float
|
||||
@@ -278,16 +343,74 @@ class Tensor:
|
||||
|
||||
@staticmethod
|
||||
def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with the given value.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.full((2, 3), 42)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
|
||||
def zeros(*shape, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with zeros.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.zeros(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(argfix(*shape), 0.0, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
|
||||
def ones(*shape, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with ones.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(argfix(*shape), 1.0, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def arange(start, stop=None, step=1, **kwargs):
|
||||
"""
|
||||
If `stop` is not specified, creates a tensor with the given shape, filled with values from `0` to `start` with the given step size.
|
||||
|
||||
If `stop` is specified, creates a tensor with the given shape, filled with values from `start` to `stop` with the given step size.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5, 10)
|
||||
print(t.numpy())
|
||||
```
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5, 10, 2)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
if stop is None: stop, start = start, 0
|
||||
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
@@ -295,49 +418,200 @@ class Tensor:
|
||||
|
||||
@staticmethod
|
||||
def eye(dim:int, **kwargs):
|
||||
"""
|
||||
Creates an identity matrix of the given dimension.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.eye(3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
|
||||
|
||||
def full_like(self, fill_value:ConstType, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the same shape as `tensor`, filled with the given value.
|
||||
If `dtype` is not specified, the dtype of `tensor` is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.ones(2, 3)
|
||||
t = Tensor.full_like(ot, 42)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
||||
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
|
||||
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
|
||||
def zeros_like(self, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the same shape as `tensor`, filled with zeros.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.ones(2, 3)
|
||||
t = Tensor.zeros_like(ot)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.full_like(0, **kwargs)
|
||||
def ones_like(self, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the same shape as `tensor`, filled with ones.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.zeros(2, 3)
|
||||
t = Tensor.ones_like(ot)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.full_like(1, **kwargs)
|
||||
|
||||
# ***** rng hlops *****
|
||||
|
||||
@staticmethod
|
||||
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
|
||||
If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
|
||||
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
|
||||
|
||||
@staticmethod
|
||||
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
|
||||
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random integer values from the interval `[low, high)`.
|
||||
If `dtype` is not specified, the default type is used.
|
||||
|
||||
You can pass in the `device` keyword argument to control device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy())
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a normal distribution with the given mean and standard deviation.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.normal(2, 3, mean=10, std=2)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values from a uniform distribution with the given lower and upper bounds.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.uniform(2, 3, low=2, high=10)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float)
|
||||
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
|
||||
|
||||
@staticmethod
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
def scaled_uniform(*shape, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with random values
|
||||
from a uniform distribution with a mean of zero and a standard deviation of `(prod(shape)**-0.5`.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.scaled_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
|
||||
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
|
||||
@staticmethod
|
||||
def glorot_uniform(*shape, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.glorot_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
|
||||
@staticmethod
|
||||
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.kaiming_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
|
||||
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
|
||||
@staticmethod
|
||||
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
|
||||
"""
|
||||
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.kaiming_normal(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user