nn.state docs cleanup (#8332)

* doc cleanup

* extension cleanup

* manual definition

* bring back accept_filename for gguf_load

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
leopf
2025-03-18 22:16:40 +01:00
committed by GitHub
parent 1ea4876dfa
commit e4dad99145
2 changed files with 27 additions and 11 deletions

View File

@@ -29,5 +29,12 @@
::: tinygrad.nn.state.get_state_dict
::: tinygrad.nn.state.get_parameters
::: tinygrad.nn.state.load_state_dict
::: tinygrad.nn.state.tar_extract
options:
show_signature: false
separate_signature: false
::: tinygrad.nn.state.torch_load
options:
show_signature: false
separate_signature: false
::: tinygrad.nn.state.gguf_load

View File

@@ -43,14 +43,14 @@ def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str,
@accept_filename
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
"""
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
Loads a .safetensor file, returning the source tensor, data start position, and metadata.
"""
data_start = int.from_bytes(t[0:8].data(), "little") + 8
return t, data_start, json.loads(t[8:data_start].data().tobytes())
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
"""
Loads a .safetensor file from disk, returning the state_dict.
Loads a .safetensor file, returning the `state_dict`.
```python
state_dict = nn.state.safe_load("test.safetensor")
@@ -63,7 +63,7 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
"""
Saves a state_dict to disk in a .safetensor file with optional metadata.
Saves a `state_dict` to disk in a .safetensor file with optional metadata.
```python
t = Tensor([1, 2, 3])
@@ -87,7 +87,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
"""
Returns a state_dict of the object, with optional prefix.
Returns a `state_dict` of the object, with optional prefix.
```python exec="true" source="above" session="tensor" result="python"
class Net:
@@ -126,7 +126,7 @@ def get_parameters(obj) -> list[Tensor]:
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
"""
Loads a state_dict into a model.
Loads a `state_dict` into a model.
```python
class Net:
@@ -162,7 +162,11 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
@accept_filename
def tar_extract(t: Tensor) -> dict[str, Tensor]:
"""
Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
```python
tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
```
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
```python
tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
@@ -176,7 +180,11 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
"""
Loads a torch .pth file from disk.
```python
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
```
Loads a torch .pth file, returning the `state_dict`.
```python
state_dict = nn.state.torch_load("test.pth")
@@ -294,13 +302,14 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
@accept_filename
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
"""
Loads a gguf file from a tensor.
Loads a .gguf file, returning the `kv_data` and `state_dict`.
```python
fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
kv_data, state_dict = gguf_load(gguf_tensor)
gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
```
NOTE: The provided tensor must be on a device that supports execution.
"""
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]