mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
nn.state docs cleanup (#8332)
* doc cleanup * extension cleanup * manual definition * bring back accept_filename for gguf_load --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -29,5 +29,12 @@
|
||||
::: tinygrad.nn.state.get_state_dict
|
||||
::: tinygrad.nn.state.get_parameters
|
||||
::: tinygrad.nn.state.load_state_dict
|
||||
::: tinygrad.nn.state.tar_extract
|
||||
options:
|
||||
show_signature: false
|
||||
separate_signature: false
|
||||
::: tinygrad.nn.state.torch_load
|
||||
options:
|
||||
show_signature: false
|
||||
separate_signature: false
|
||||
::: tinygrad.nn.state.gguf_load
|
||||
|
||||
@@ -43,14 +43,14 @@ def accept_filename(func: Callable[[Tensor], T]) -> Callable[[Union[Tensor, str,
|
||||
@accept_filename
|
||||
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
|
||||
"""
|
||||
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
|
||||
Loads a .safetensor file, returning the source tensor, data start position, and metadata.
|
||||
"""
|
||||
data_start = int.from_bytes(t[0:8].data(), "little") + 8
|
||||
return t, data_start, json.loads(t[8:data_start].data().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Loads a .safetensor file from disk, returning the state_dict.
|
||||
Loads a .safetensor file, returning the `state_dict`.
|
||||
|
||||
```python
|
||||
state_dict = nn.state.safe_load("test.safetensor")
|
||||
@@ -63,7 +63,7 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
|
||||
|
||||
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
|
||||
"""
|
||||
Saves a state_dict to disk in a .safetensor file with optional metadata.
|
||||
Saves a `state_dict` to disk in a .safetensor file with optional metadata.
|
||||
|
||||
```python
|
||||
t = Tensor([1, 2, 3])
|
||||
@@ -87,7 +87,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
|
||||
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
|
||||
"""
|
||||
Returns a state_dict of the object, with optional prefix.
|
||||
Returns a `state_dict` of the object, with optional prefix.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
class Net:
|
||||
@@ -126,7 +126,7 @@ def get_parameters(obj) -> list[Tensor]:
|
||||
|
||||
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
|
||||
"""
|
||||
Loads a state_dict into a model.
|
||||
Loads a `state_dict` into a model.
|
||||
|
||||
```python
|
||||
class Net:
|
||||
@@ -162,7 +162,11 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
|
||||
@accept_filename
|
||||
def tar_extract(t: Tensor) -> dict[str, Tensor]:
|
||||
"""
|
||||
Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
|
||||
```python
|
||||
tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
|
||||
```
|
||||
|
||||
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
|
||||
|
||||
```python
|
||||
tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
|
||||
@@ -176,7 +180,11 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
|
||||
@accept_filename
|
||||
def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
"""
|
||||
Loads a torch .pth file from disk.
|
||||
```python
|
||||
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
|
||||
```
|
||||
|
||||
Loads a torch .pth file, returning the `state_dict`.
|
||||
|
||||
```python
|
||||
state_dict = nn.state.torch_load("test.pth")
|
||||
@@ -294,13 +302,14 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
||||
@accept_filename
|
||||
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
|
||||
"""
|
||||
Loads a gguf file from a tensor.
|
||||
Loads a .gguf file, returning the `kv_data` and `state_dict`.
|
||||
|
||||
```python
|
||||
fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
|
||||
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
|
||||
kv_data, state_dict = gguf_load(gguf_tensor)
|
||||
gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
|
||||
kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
|
||||
```
|
||||
|
||||
NOTE: The provided tensor must be on a device that supports execution.
|
||||
"""
|
||||
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
|
||||
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]
|
||||
|
||||
Reference in New Issue
Block a user