remove llvm_bf16_cast (#12075)

This commit is contained in:
nimlgen
2025-09-08 20:51:15 +03:00
committed by GitHub
parent 11213398b9
commit 9182948951
3 changed files with 3 additions and 11 deletions

View File

@@ -249,8 +249,5 @@ def convert_from_gguf(weights:dict[str, Tensor], n_layers:int):
return sd
def fix_bf16(weights:dict[Any, Tensor]):
if getenv("SUPPORT_BF16", 1):
# TODO: without casting to float16, 70B llama OOM on tinybox.
return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
# TODO: check if device supports bf16
return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
# TODO: without casting to float16, 70B llama OOM on tinybox.
return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}

View File

@@ -318,7 +318,7 @@ class TestDiskTensor(unittest.TestCase):
with open(temp('dt_bf16_disk_write_read_bf16'), "wb") as f: f.write(adat)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('dt_bf16_disk_write_read_bf16')}")
ct = t.llvm_bf16_cast(dtypes.float)
ct = t.to(Device.DEFAULT).cast(dtypes.float)
assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]
def test_copy_from_disk(self):

View File

@@ -4216,11 +4216,6 @@ class Tensor(MathTrait):
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").cast(dtype)
def cast(self, dtype:DTypeLike) -> Tensor:
"""
Casts `self` to the given `dtype`.