move bf16 cast hack to Tensor.llvm_bf16_cast (#3788)
@@ -34,7 +34,7 @@ if __name__ == "__main__":
   part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
 
   # fix bf16, TODO: check if device supports bf16
-  def fix_bf16(weights): return {k:v.llvm().cast(dtypes.float16).to(Device.DEFAULT) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  def fix_bf16(weights): return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
 
   with Timing("weights -> model: "):
     nn.state.load_state_dict(model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False)
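Note: the new helper only rewrites the bfloat16 entries of the state dict, converting them to half via the LLVM cast hack and sending each tensor back to its original device. A minimal numpy stand-in for the same filtering pattern (illustrative only; uint16 arrays play the role of raw bfloat16 bits, and the names are not from the repo):

import numpy as np

def fix_bf16_np(weights: dict) -> dict:
  # widen raw bf16 bits into the high half of a float32 word, then narrow to float16
  def widen(v): return (v.astype(np.uint32) << 16).view(np.float32).astype(np.float16)
  # only bf16 (here: uint16) entries are rewritten; everything else passes through
  return {k: widen(v) if v.dtype == np.uint16 else v for k, v in weights.items()}

weights = {"w": np.array([0x3FC0, 0x4000], dtype=np.uint16), "b": np.array([1.0, 2.0], dtype=np.float32)}
out = fix_bf16_np(weights)
assert out["w"].dtype == np.float16 and out["w"].tolist() == [1.5, 2.0]
assert out["b"].dtype == np.float32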
@@ -117,7 +117,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
 
-@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP", "METAL"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16(unittest.TestCase):
   def test_bf16_creation_numpy(self):
     data = [-1, 1, 2]
@@ -127,21 +127,19 @@ class TestBFloat16(unittest.TestCase):
     assert tnp.dtype == np.float32
     np.testing.assert_allclose(tnp, np.array(data))
 
-  @unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
   def test_bf16_ones(self):
     # TODO: fix this with correct bfloat16 cast
     t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
     assert t.dtype == dtypes.bfloat16
     np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
 
-  @unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
   def test_bf16_eye(self):
     # TODO: fix this with correct bfloat16 cast
     t = Tensor.eye(3, dtype=dtypes.bfloat16)
     assert t.dtype == dtypes.bfloat16
     np.testing.assert_allclose(t.numpy(), np.eye(3))
 
-@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
   def test_bf16_to_float(self):
     _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
@@ -157,7 +155,7 @@ class TestBFloat16DType(unittest.TestCase):
     back = t.cast(dtypes.float32)
     assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
 
-@unittest.skipUnless(Device.DEFAULT in ["HIP"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16DTypeCast(unittest.TestCase):
   def test_f16_to_bf16_conversion(self):
     original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
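Note: the decorator changes above swap hard-coded device whitelists for a single is_dtype_supported(dtypes.bfloat16) check, so the tests follow actual backend capability rather than a maintained list. As a rough sketch of the same gating pattern outside tinygrad (the probe below is hypothetical and uses the ml_dtypes package, not the repo's helper):

import importlib.util
import unittest

# hypothetical capability probe: gate the tests on whether a bfloat16 dtype is available at all
HAS_BF16 = importlib.util.find_spec("ml_dtypes") is not None

@unittest.skipUnless(HAS_BF16, "bfloat16 not supported")
class TestBF16Gated(unittest.TestCase):
  def test_roundtrip(self):
    import numpy as np, ml_dtypes
    x = np.array([1.5, 2.0], dtype=ml_dtypes.bfloat16)
    self.assertEqual(x.astype(np.float32).tolist(), [1.5, 2.0])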
@@ -216,7 +216,7 @@ class TestDiskTensor(unittest.TestCase):
 
   @unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI")
   def test_bf16_disk_write_read(self):
-    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
+    t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
     t.to(f"disk:{temp('f32')}").realize()
 
     # hack to "cast" f32 -> bf16
@@ -224,9 +224,8 @@ class TestDiskTensor(unittest.TestCase):
     adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
     with open(temp('bf16'), "wb") as f: f.write(adat)
 
-    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
-    back = t.cast(dtypes.float32)
-    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
+    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm_bf16_cast(dtypes.float)
+    assert t.numpy().tolist() == [9984., -1, -1000, -9984, 20]
 
 if __name__ == "__main__":
   unittest.main()
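Note: the test fakes an f32 -> bf16 cast by keeping only bytes 2..3 of every little-endian float32, i.e. the sign, exponent, and top 7 mantissa bits, which is exactly the bfloat16 encoding; values like 10000 lose their low mantissa bits and come back as 9984. A standalone numpy check of that round-trip (illustrative, assumes a little-endian host):

import numpy as np

vals = np.array([10000, -1, -1000, -10000, 20], dtype=np.float32)
dat = vals.tobytes()
# keep the upper two bytes of each 4-byte float: truncation to bfloat16
bf16_bits = b''.join(dat[i+2:i+4] for i in range(0, len(dat), 4))
# widen back by placing the 16 bits in the high half of a float32 word
back = (np.frombuffer(bf16_bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
assert back.tolist() == [9984., -1., -1000., -9984., 20.]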
@@ -91,8 +91,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
     if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
       intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
       assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
-      assert storage[1] != dtypes.bfloat16, "can't CPU permute BF16"
+      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
+      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
       # TODO: BUG: a ".realize()" is needed here for 'GPU=1 python3 test/models/test_efficientnet.py TestEfficientNet.test_car'
       ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes).realize()
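Note on the permuted-view path: argsort(permute_indexes) is the inverse permutation, so reordering the saved shape by it recovers the shape the data is actually laid out in on disk, and permuting that contiguous reshape reproduces the saved view. A small numpy illustration (values are made up):

import numpy as np

storage = np.arange(6, dtype=np.float32)   # flat data as it sits on disk
permute_indexes = [1, 0]                   # the checkpoint stored a transposed view
final_shape = (3, 2)                       # shape of that saved view
# storage-order shape = saved shape reordered by the inverse permutation
intermediate_shape = tuple(final_shape[i] for i in np.argsort(permute_indexes))
view = storage.reshape(intermediate_shape).transpose(permute_indexes)
assert intermediate_shape == (2, 3) and view.shape == final_shape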
@@ -993,11 +993,11 @@ class Tensor:
 
   # ***** cast ops *****
 
-  def cast(self, dtype:DType) -> Tensor:
-    if self.dtype == dtype: return self
+  def llvm_bf16_cast(self, dtype:DType):
     # hack for devices that don't support bfloat16
-    if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
-    return mlops.Cast.apply(self, dtype=dtype)
+    assert self.dtype == dtypes.bfloat16
+    return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
+  def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype)
   def bitcast(self, dtype:DType) -> Tensor:
     assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes"
     return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
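Note: llvm_bf16_cast works because bfloat16 is the top half of an IEEE-754 float32: same sign bit, same 8 exponent bits, and the 7 highest mantissa bits. The method routes through the LLVM backend, then treats the bf16 bits as integers: bitcast to uint16, widen to uint32, multiply by 1<<16 (a left shift by 16), bitcast to float32, which reconstructs the value exactly, and finally cast to the requested dtype. A one-value sanity check with the standard library (illustrative):

import struct

bf16_bits = 0x3FC0                               # bfloat16 bit pattern for 1.5
f32_bits = bf16_bits << 16                       # shift into the high half of a 32-bit word
value, = struct.unpack("<f", struct.pack("<I", f32_bits))
assert value == 1.5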