Tuple -> tuple, List -> list [pr] (#8936)

This commit is contained in:
chenyu
2025-02-06 14:21:19 -05:00
committed by GitHub
parent d5183e1584
commit a092b6395d
9 changed files with 43 additions and 47 deletions

View File

@@ -1,4 +1,4 @@
from typing import Tuple, Union, Optional, Dict, Any
from typing import Union, Optional, Any
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
@@ -15,7 +15,7 @@ def complex_mult(A, c, d):
co = a*d + b*c
return ro.cat(co, dim=-1)
def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> tuple[Tensor, Tensor]:
assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -181,7 +181,7 @@ class Transformer:
# *** helpers ***
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
@@ -207,7 +207,7 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
sd[keymap[k]] = v
return sd
def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
keymap = {
"token_embd.weight": "tok_embeddings.weight",
**{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
@@ -222,7 +222,7 @@ def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
sd["output.weight"] = weights["token_embd.weight"]
return sd
def fix_bf16(weights:Dict[Any, Tensor]):
def fix_bf16(weights:dict[Any, Tensor]):
if getenv("SUPPORT_BF16", 1):
# TODO: without casting to float16, 70B llama OOM on tinybox.
return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}