Mirror of https://github.com/tinygrad/tinygrad.git
remove llvm_bf16_cast (#12075)
@@ -249,8 +249,5 @@ def convert_from_gguf(weights:dict[str, Tensor], n_layers:int):
   return sd
 
 def fix_bf16(weights:dict[Any, Tensor]):
-  if getenv("SUPPORT_BF16", 1):
-    # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
-  # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  # TODO: without casting to float16, 70B llama OOM on tinybox.
+  return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
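For context, a minimal sketch (not part of the diff) of the cast pattern this hunk keeps: bfloat16 weights are upcast to float32 and then down to float16, since per the TODO the 70B llama OOMs on tinybox without the float16 step. The weight names and shapes here are illustrative, and it assumes the device can materialize bfloat16 tensors:

from tinygrad import Tensor, dtypes

# illustrative weights dict; names and shapes are made up for the sketch
w = {"tok_embeddings.weight": Tensor.ones(4, 4, dtype=dtypes.bfloat16),
     "norm.weight": Tensor.ones(4, dtype=dtypes.float16)}
# the kept pattern: bf16 -> fp32 -> fp16; non-bf16 tensors pass through unchanged
fixed = {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v
         for k, v in w.items()}
assert all(v.dtype == dtypes.float16 for v in fixed.values())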
@@ -318,7 +318,7 @@ class TestDiskTensor(unittest.TestCase):
     with open(temp('dt_bf16_disk_write_read_bf16'), "wb") as f: f.write(adat)
 
     t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('dt_bf16_disk_write_read_bf16')}")
-    ct = t.llvm_bf16_cast(dtypes.float)
+    ct = t.to(Device.DEFAULT).cast(dtypes.float)
     assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]
 
   def test_copy_from_disk(self):
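The test now exercises the replacement directly: instead of llvm_bf16_cast, the disk-backed tensor is moved to the default compute device and cast there. A standalone sketch of that pattern, assuming the default device supports bfloat16 (values taken from the test's assertion):

from tinygrad import Tensor, dtypes, Device

t = Tensor([9984., -1, -1000, -9984, 20]).cast(dtypes.bfloat16)  # stand-in for the disk-backed tensor
ct = t.to(Device.DEFAULT).cast(dtypes.float)  # move to the default device, then cast
print(ct.numpy().tolist())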
@@ -4216,11 +4216,6 @@ class Tensor(MathTrait):
 
   # ***** cast ops *****
 
-  def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
-    # hack for devices that don't support bfloat16
-    assert self.dtype == dtypes.bfloat16
-    return self.to("LLVM").cast(dtype)
-
   def cast(self, dtype:DTypeLike) -> Tensor:
     """
     Casts `self` to the given `dtype`.
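For reference, the removed helper reproduced as a standalone function: it bounced the tensor through the LLVM (CPU) backend, which handles bfloat16 even when the tensor's own device does not. This is lifted directly from the deleted lines, minus the Tensor method wrapper:

from tinygrad import Tensor, dtypes

def llvm_bf16_cast(t: Tensor, dtype) -> Tensor:
  # hack for devices that don't support bfloat16 (removed in this commit)
  assert t.dtype == dtypes.bfloat16
  return t.to("LLVM").cast(dtype)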