diff --git a/examples/llama3.py b/examples/llama3.py
index 1710823687..5d9ff40cd4 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -142,8 +142,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   if quantize == "int8": linear = Int8Linear
   elif quantize == "nf4": linear = NF4Linear(64)
   else: linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)

   # load weights
   if model_path.is_dir():
diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py
index 9be9125535..13471a2d7f 100644
--- a/tinygrad/codegen/transcendental.py
+++ b/tinygrad/codegen/transcendental.py
@@ -17,7 +17,7 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
   qy = (ny - qx * dy) * t
   return qx, qy
 # *** helper functions for bit manipulation ***
-def significand_bits(d:DType) -> int: return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
+def significand_bits(d:DType) -> int: return dtypes.finfo(d)[1]
 def exponent_bias(d:DType) -> int: return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
 def exponent_mask(d:DType) -> int: return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
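
The `significand_bits` change swaps a hardcoded per-dtype table for `dtypes.finfo(d)[1]`. A minimal sanity-check sketch, assuming `dtypes.finfo` returns a bit-count tuple whose second element is the mantissa width (as the diff implies), that the new lookup reproduces the values from the removed table:

```python
# Sketch: check that dtypes.finfo(d)[1] matches the removed hardcoded significand widths.
# Assumption: finfo returns (exponent_bits, mantissa_bits) for float dtypes.
from tinygrad import dtypes

OLD_TABLE = {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}

for d, bits in OLD_TABLE.items():
  got = dtypes.finfo(d)[1]
  assert got == bits, f"{d}: finfo gives {got}, expected {bits}"
print("significand_bits via dtypes.finfo matches the old table")
```

The same pattern could presumably fold `exponent_bias` and `exponent_mask` into `finfo`-derived values as well, but this diff only touches `significand_bits`.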