bfloat16 in LLVM (enough for llama 2) (#1293)

* add bf16 support to LLVM

* bf16 read works
George Hotz authored on 2023-07-19 20:18:32 -07:00 (committed by GitHub)
parent 74e63fe4ee
commit ca77d6cd72
6 changed files with 50 additions and 5 deletions
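For context on the "bf16 read works" line: a bfloat16 value is the top 16 bits of an IEEE float32, so reading one amounts to widening the raw 16-bit pattern to 32 bits and shifting it left by 16 (in LLVM IR terms, a zext plus shl plus bitcast). A minimal numpy sketch of that decode step, with bf16_to_float32 as an illustrative name rather than code from this commit:

import numpy as np

def bf16_to_float32(raw: np.ndarray) -> np.ndarray:
  # Illustrative helper, not from this commit: bfloat16 stores the upper
  # 16 bits of a float32, so widen the bits and shift them into place.
  assert raw.dtype == np.uint16
  return (raw.astype(np.uint32) << 16).view(np.float32)

# 0x3F80 and 0x4000 are the bfloat16 bit patterns for 1.0 and 2.0.
print(bf16_to_float32(np.array([0x3F80, 0x4000], dtype=np.uint16)))  # [1. 2.]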

tinygrad/tensor.py

@@ -63,10 +63,10 @@ class Tensor:
       return

     if data.__class__ is list:
+      assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
       data = np.array(data, dtype=(dtype or Tensor.default_type).np)

-    if data.__class__ is np.ndarray:
-      data = cast(np.ndarray, data)
+    if isinstance(data, np.ndarray):
       data = LazyBuffer.fromCPU(data)
       self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
       return
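Two details in this hunk are worth noting. The new assert exists because dtypes with no numpy equivalent (bfloat16 among them) have dtype.np set to None, so the list path, which round-trips through np.array, cannot construct them. And switching the __class__ comparison to isinstance lets the type checker narrow data to np.ndarray, which is what makes the old explicit cast redundant. A hedged sketch of what the guard accepts and rejects, assuming the dtypes module of this era lives in tinygrad.helpers:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes  # assumed import path for this era

Tensor([1.0, 2.0], dtype=dtypes.float32)   # ok: dtypes.float32.np is np.float32
Tensor([1.0, 2.0], dtype=dtypes.bfloat16)  # AssertionError: no numpy dtype for bf16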