Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
Tuple -> tuple, List -> list [pr] (#8936)
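The rename follows PEP 585: since Python 3.9 the built-in container types are themselves generic, so annotations can subscript tuple, list, and dict directly, and the typing aliases Tuple, List, and Dict become unnecessary. Union, Optional, and Any still come from typing, which is why the commit keeps those imports. A minimal sketch of the pattern, with illustrative names not taken from the patch:

from typing import Any, Optional  # Union/Optional/Any still live in typing

# Built-in generics (PEP 585, Python 3.9+) replace the typing aliases:
def first(pairs: list[tuple[str, int]]) -> Optional[tuple[str, int]]:  # was List[Tuple[str, int]]
  return pairs[0] if pairs else None

def tally(names: list[str]) -> dict[str, Any]:  # was Dict[str, Any]
  return {n: len(n) for n in names}

The diff below applies exactly this substitution to the annotated signatures, leaving runtime behavior unchanged: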
@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Union, Optional, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 
@@ -15,7 +15,7 @@ def complex_mult(A, c, d):
   co = a*d + b*c
   return ro.cat(co, dim=-1)
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -181,7 +181,7 @@ class Transformer:
 
 # *** helpers ***
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
+def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
@@ -207,7 +207,7 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
   return sd
 
-def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
+def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
   keymap = {
     "token_embd.weight": "tok_embeddings.weight",
     **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
@@ -222,7 +222,7 @@ def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
   sd["output.weight"] = weights["token_embd.weight"]
   return sd
 
-def fix_bf16(weights:Dict[Any, Tensor]):
+def fix_bf16(weights:dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
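As context for the @@ -15 hunk: complex_mult computes an elementwise complex product for the rotary embedding, and the `co = a*d + b*c` context line is the imaginary part of (a+bi)(c+di); ro, computed just above the shown context, is presumably the real part a*c - b*d. A plain-Python sketch of that arithmetic (complex_mult_pairs is a hypothetical name, not tinygrad code):

def complex_mult_pairs(x: tuple[float, float], c: float, d: float) -> tuple[float, float]:
  # (a + b*i) * (c + d*i) = (a*c - b*d) + (a*d + b*c)*i
  a, b = x
  return (a*c - b*d, a*d + b*c)

# rotating (1, 0) by 90 degrees: c = cos(pi/2) = 0, d = sin(pi/2) = 1
assert complex_mult_pairs((1.0, 0.0), 0.0, 1.0) == (0.0, 1.0)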