some docstrings (#4201)

* feat: create and data access docstrings

* fix: linter

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
wozeparrot
2024-04-21 08:34:08 -04:00
committed by GitHub
parent 30fc1ad415
commit 4c99d49c4d
3 changed files with 289 additions and 13 deletions

View File

@@ -91,8 +91,9 @@ plugins:
show_source: true
signature_crossrefs: true
summary: true
- markdown-exec
#- gen-files:
# scripts:
# - docs/gen_ref_pages.py
#- literate-nav:
# nav_file: SUMMARY.md
# nav_file: SUMMARY.md

View File

@@ -57,6 +57,7 @@ setup(name='tinygrad',
"mkdocs-material",
"mkdocstrings[python]",
"markdown-callouts",
"markdown-exec[ansi]"
],
'testing_tf': [
"tensorflow==2.15.1",

View File

@@ -71,7 +71,13 @@ def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for
def broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
class Tensor:
"""A `Tensor` is a multi-dimensional matrix containing elements of a single data type."""
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor"
from tinygrad import Tensor
```
"""
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
training: ClassVar[bool] = False
@@ -193,14 +199,40 @@ class Tensor:
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._data().cast(self.dtype.fmt, self.shape)
def item(self) -> ConstType:
"""Returns the value of this tensor as a standard Python number."""
"""
Returns the value of this tensor as a standard Python number.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor(42)
print(t.item())
```
"""
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert self.numel() == 1, "must have one element for item"
return self._data().cast(self.dtype.fmt)[0]
# TODO: should be Tensor.tolist() -> Union[List[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
def tolist(self) -> Union[Sequence[ConstType], ConstType]: return self.data().tolist()
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
"""
Returns the value of this tensor as a nested list.
Currently this only works for flattened tensors.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3, 4])
print(t.tolist())
```
"""
return self.data().tolist()
def numpy(self) -> np.ndarray:
"""
Returns the value of this tensor as a numpy array.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3, 4])
print(t.numpy())
```
"""
if self.dtype == dtypes.bfloat16: return self.float().numpy()
assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
@@ -241,15 +273,48 @@ class Tensor:
return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
def empty(*shape, **kwargs):
"""
Creates an empty tensor with the given shape.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 3)
print(t.shape)
```
"""
return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs)
_seed: int = int(time.time())
_rng_counter: Optional[Tensor] = None
@staticmethod
def manual_seed(seed=0): Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
def manual_seed(seed=0):
"""
Sets the seed for random operations.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
print(Tensor._seed)
```
"""
Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
@staticmethod
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs):
"""
Creates a tensor with the given shape, filled with random values between the interval `[0, 1)`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.rand(2, 3)
print(t.numpy())
```
"""
if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
if not THREEFRY.value:
# for bfloat16, numpy rand passes buffer in float
@@ -278,16 +343,74 @@ class Tensor:
@staticmethod
def full(shape:Tuple[sint, ...], fill_value:ConstType, **kwargs):
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.full((2, 3), 42)
print(t.numpy())
```
"""
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
def zeros(*shape, **kwargs):
"""
Creates a tensor with the given shape, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.zeros(2, 3)
print(t.numpy())
```
"""
return Tensor.full(argfix(*shape), 0.0, **kwargs)
@staticmethod
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
def ones(*shape, **kwargs):
"""
Creates a tensor with the given shape, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(t.numpy())
```
"""
return Tensor.full(argfix(*shape), 1.0, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
"""
If `stop` is not specified, creates a tensor with the given shape, filled with values from `0` to `start` with the given step size.
If `stop` is specified, creates a tensor with the given shape, filled with values from `start` to `stop` with the given step size.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(5)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(5, 10)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(5, 10, 2)
print(t.numpy())
```
"""
if stop is None: stop, start = start, 0
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
@@ -295,49 +418,200 @@ class Tensor:
@staticmethod
def eye(dim:int, **kwargs):
"""
Creates an identity matrix of the given dimension.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.eye(3)
print(t.numpy())
```
"""
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
def full_like(self, fill_value:ConstType, **kwargs):
"""
Creates a tensor with the same shape as `tensor`, filled with the given value.
If `dtype` is not specified, the dtype of `tensor` is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
ot = Tensor.ones(2, 3)
t = Tensor.full_like(ot, 42)
print(t.numpy())
```
"""
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
def zeros_like(self, **kwargs):
"""
Creates a tensor with the same shape as `tensor`, filled with zeros.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
ot = Tensor.ones(2, 3)
t = Tensor.zeros_like(ot)
print(t.numpy())
```
"""
return self.full_like(0, **kwargs)
def ones_like(self, **kwargs):
"""
Creates a tensor with the same shape as `tensor`, filled with ones.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
ot = Tensor.zeros(2, 3)
t = Tensor.ones_like(ot)
print(t.numpy())
```
"""
return self.full_like(1, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 3)
print(t.numpy())
```
"""
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random integer values from the interval `[low, high)`.
If `dtype` is not specified, the default type is used.
You can pass in the `device` keyword argument to control device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randint(2, 3, low=5, high=10)
print(t.numpy())
"""
return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a normal distribution with the given mean and standard deviation.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.normal(2, 3, mean=10, std=2)
print(t.numpy())
```
"""
return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values from a uniform distribution with the given lower and upper bounds.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.uniform(2, 3, low=2, high=10)
print(t.numpy())
```
"""
dtype = kwargs.pop("dtype", dtypes.default_float)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
def scaled_uniform(*shape, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with random values
from a uniform distribution with a mean of zero and a standard deviation of `(prod(shape)**-0.5`.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.scaled_uniform(2, 3)
print(t.numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
"""
<https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.glorot_uniform(2, 3)
print(t.numpy())
```
"""
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.kaiming_uniform(2, 3)
print(t.numpy())
```
"""
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
"""
<https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.kaiming_normal(2, 3)
print(t.numpy())
```
"""
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)