feat: small things from default_threefry (#6708)

This commit is contained in:
wozeparrot
2024-09-24 17:00:47 +08:00
committed by GitHub
parent f2700ac58a
commit f932116e05
2 changed files with 2 additions and 3 deletions

View File

@@ -142,8 +142,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
if quantize == "int8": linear = Int8Linear
elif quantize == "nf4": linear = NF4Linear(64)
else: linear = nn.Linear
with Context(THREEFRY=0):
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
# load weights
if model_path.is_dir():

View File

@@ -17,7 +17,7 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
qy = (ny - qx * dy) * t
return qx, qy
# *** helper functions for bit manipulation ***
# NOTE(review): significand_bits is defined twice in a row here. This looks like the
# removed/added pair of a diff hunk whose +/- markers were lost; if both lines are kept
# as-is, the later definition replaces the earlier one at import time.
# Table-based variant: explicitly stored significand width for float64/float32/float16;
# any other dtype raises KeyError from the dict lookup.
def significand_bits(d:DType) -> int: return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
# Lookup-based variant: returns element [1] of dtypes.finfo(d).
# Presumably that element is the mantissa width — confirm against dtypes.finfo.
def significand_bits(d:DType) -> int: return dtypes.finfo(d)[1]
def exponent_bias(d:DType) -> int:
  """Return the exponent-bias constant used by this module for float dtype `d`.

  Only float64, float32 and float16 are supported; any other dtype raises KeyError.
  """
  # NOTE(review): each value is one less than the IEEE-754 bias (1023 / 127 / 15).
  # Presumably intentional for the callers of this helper — confirm before changing.
  bias_by_dtype = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}
  return bias_by_dtype[d]
def exponent_mask(d:DType) -> int:
  """Return the all-ones exponent mask for float dtype `d`.

  The masks are 11, 8 and 5 bits wide for float64, float32 and float16 respectively.
  Any other dtype raises KeyError.
  """
  mask_by_dtype = {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}
  return mask_by_dtype[d]