int8/uint8 support (#837)

* feat: int8 support

* feat: uint8 support

* feat: int8 tests

* fix: fix uint8 on clang

* feat: test casting between int8/uint8/float16/float32

* clean: way cleaner dtype tests

* feat: preprocess_imagenet using the correct dtype

* feat: add test for overflow between uint8 and int8
Author: wozeparrot
Committed: 2023-05-29 02:15:06 -04:00 (via GitHub)
Commit: 2fd2fb6380 (parent: 2939e40b98)
9 changed files with 94 additions and 65 deletions
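
Taken together, the changes below let tensors be created in, computed in, and cast between the two new 8-bit dtypes. A minimal usage sketch of what the commit enables, mirroring the new tests:

    from tinygrad.tensor import Tensor
    from tinygrad.helpers import dtypes

    a = Tensor([1, 2, 3, 4], dtype=dtypes.int8)   # create an int8 tensor
    b = a.cast(dtypes.uint8)                      # cast between the new dtypes
    print((a + a).numpy())                        # int8 math stays int8: [2 4 6 8]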

@@ -1,3 +1,4 @@
+from tinygrad.helpers import dtypes
 from tinygrad.tensor import Tensor
 from datasets.imagenet import iterate, get_val_files
@@ -8,13 +9,12 @@ if __name__ == "__main__":
   idx = 0
   for x,y in iterate(shuffle=False):
-    print(x.shape, y.shape)
+    print(x.shape, y.shape, x.dtype, y.dtype)
     assert x.shape[0] == y.shape[0]
     bs = x.shape[0]
     if X is None:
-      # TODO: need uint8 support
-      X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x")
-      Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y")
+      X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
+      Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
     print(X.shape, Y.shape)
     X[idx:idx+bs].assign(x)
     Y[idx:idx+bs].assign(y)
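
The hunk above is why the commit message mentions preprocess_imagenet: the preprocessed images can now live on disk as uint8 (1 byte per element, 4x smaller than float32) and the labels as int64, matching what the loader actually produces. The pattern, sketched with an illustrative batch shape (the real shapes come from iterate()):

    from tinygrad.tensor import Tensor
    from tinygrad.helpers import dtypes

    # a disk-backed uint8 tensor; before this commit the dtype defaulted to float32
    X = Tensor.empty(16, 224, 224, 3, device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)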

@@ -8,83 +8,110 @@ from tinygrad.tensor import Tensor, dtypes
 # for LLVM, it segfaults because it can't link to the casting function
 @unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
 class TestDtype(unittest.TestCase):
-  def test_half_to_np(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
+  def _test_to_np(self, a, np_dtype, target):
     print(a)
     na = a.numpy()
     print(na, na.dtype, a.lazydata.realized)
-    assert na.dtype == np.float16
-    np.testing.assert_allclose(na, [1,2,3,4])
+    assert na.dtype == np_dtype
+    np.testing.assert_allclose(na, target)
 
-  def test_half_add(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
-    b = Tensor([1,2,3,4], dtype=dtypes.float16)
+  def test_half_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
+  def test_int8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
+  def test_uint8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
+
+  def _test_cast(self, a, target_dtype, target):
+    print(a)
+    b = a.cast(target_dtype)
+    print(b.numpy())
+    assert b.dtype == target_dtype
+    np.testing.assert_allclose(b.numpy(), target)
+
+  def test_float_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
+  def test_float_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
+  def test_float_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
+  def test_half_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
+  def test_half_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
+  def test_half_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
+  def test_int8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
+  def test_int8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
+  def test_int8_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
+  def test_uint8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
+  def test_uint8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
+  def test_uint8_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
+
+  def _test_add(self, a, b, target_dtype, target):
     c = a+b
     print(c.numpy())
-    assert c.dtype == dtypes.float16
-    np.testing.assert_allclose(c.numpy(), [2,4,6,8])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
 
-  def test_half_mul(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
-    b = Tensor([1,2,3,4], dtype=dtypes.float16)
+  def test_half_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
+  def test_int8_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
+
+  def _test_mul(self, a, b, target_dtype, target):
     c = a*b
     print(c.numpy())
-    assert c.dtype == dtypes.float16
-    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
 
-  def test_half_matmul(self):
-    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
-    b = Tensor.eye(2, dtype=dtypes.float16)
+  def test_half_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
+  def test_int8_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
+
+  def _test_matmul(self, a, b, target_dtype, target):
     c = a@b
     print(c.numpy())
-    assert c.dtype == dtypes.float16
-    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
 
-  def test_upcast_float(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
-    print(a)
-    fa = a.float()
-    assert a.device == fa.device
-    assert a.requires_grad == fa.requires_grad
-    na = fa.numpy()
-    print(na, na.dtype)
-    assert na.dtype == np.float32
-    np.testing.assert_allclose(na, [1,2,3,4])
+  def test_half_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
+  def test_int8_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
 
-  def test_downcast_float(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float32, requires_grad=False)
-    print(a)
-    ha = a.half()
-    assert a.device == ha.device
-    assert a.requires_grad == ha.requires_grad
-    na = ha.numpy()
-    print(na, na.dtype)
-    assert na.dtype == np.float16
-    np.testing.assert_allclose(na, [1,2,3,4])
-
-  def test_half_add_upcast(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
-    b = Tensor([1,2,3,4], dtype=dtypes.float32)
+  def _test_add_upcast(self, a, b, target_dtype, target):
     c = a+b
     print(c.numpy())
-    assert c.dtype == dtypes.float32
-    np.testing.assert_allclose(c.numpy(), [2,4,6,8])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
 
-  def test_half_mul_upcast(self):
-    a = Tensor([1,2,3,4], dtype=dtypes.float16)
-    b = Tensor([1,2,3,4], dtype=dtypes.float32)
+  def test_half_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
+  def test_int8_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
+  def test_int8_add_upcast_half(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
+
+  def _test_mul_upcast(self, a, b, target_dtype, target):
     c = a*b
     print(c.numpy())
-    assert c.dtype == dtypes.float32
-    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
 
-  def test_half_matmul_upcast(self):
-    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
-    b = Tensor.eye(2, dtype=dtypes.float32)
+  def test_half_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
+  def test_int8_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
+  def test_int8_mul_upcast_half(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
+
+  def _test_matmul_upcast(self, a, b, target_dtype, target):
     c = a@b
     print(c.numpy())
-    assert c.dtype == dtypes.float32
-    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+    assert c.dtype == target_dtype
+    np.testing.assert_allclose(c.numpy(), target)
+
+  def test_half_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
+  def test_int8_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
+  def test_int8_matmul_upcast_half(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
+
+  def test_int8_to_uint8_negative(self):
+    a = Tensor([-1, -2, -3, -4], dtype=dtypes.int8)
+    print(a)
+    b = a.cast(dtypes.uint8)
+    print(b.numpy())
+    np.testing.assert_allclose(b.numpy(), [255, 254, 253, 252])
+
+  def test_uint8_to_int8_overflow(self):
+    a = Tensor([255, 254, 253, 252], dtype=dtypes.uint8)
+    print(a)
+    b = a.cast(dtypes.int8)
+    print(b.numpy())
+    np.testing.assert_allclose(b.numpy(), [-1, -2, -3, -4])
 
 if __name__ == '__main__':
   unittest.main()
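
The two tests at the end pin down the cast semantics: int8<->uint8 conversion reinterprets the byte with two's-complement wraparound rather than clamping, so -1 (0xFF) becomes 255 and back. The same behavior in plain numpy, for comparison:

    import numpy as np

    a = np.array([-1, -2, -3, -4], dtype=np.int8)
    print(a.astype(np.uint8))  # [255 254 253 252] -- same bytes, reinterpreted

    b = np.array([255, 254, 253, 252], dtype=np.uint8)
    print(b.astype(np.int8))   # [-1 -2 -3 -4]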

@@ -132,7 +132,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
     if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
       assert newvar.ltype == LocalTypes.float, "const can't be float4"
       # nan? inf?
-      val = f"{bufs[args.i].realized._buf}f"
+      val = f"{bufs[args.i].realized._buf}" + ("f" if bufs[args.i].dtype not in (dtypes.int8, dtypes.uint8) else "")
     elif isinstance(bufs[args.i].dtype, ImageDType):
       assert newvar.ltype == LocalTypes.float4, "image must be float4"
       prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
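
The one-line change matters because C float literals need the f suffix (1.0f) to stay single precision, while something like 255f is not a valid integer constant; so the suffix is now appended only for non-8-bit dtypes. A standalone sketch of the rule, with names simplified from the surrounding code:

    def render_const(val, dtype_name):
      # float constants get a trailing 'f' in the generated C; int8/uint8 constants must not
      return f"{val}" + ("f" if dtype_name not in ("char", "uchar") else "")

    print(render_const(1.5, "float"))  # 1.5f
    print(render_const(255, "uchar"))  # 255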

@@ -201,7 +201,7 @@ class Linearizer:
       store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
 
       # float4 grouping (optional)
-      should_upcast = self.supports_float4 and self.bufs[i].dtype != dtypes.float16 and len(self.float4_axis(i)) == 1
+      should_upcast = self.supports_float4 and (self.bufs[i].dtype not in (dtypes.float16, dtypes.int8, dtypes.uint8)) and len(self.float4_axis(i)) == 1
       if should_upcast:
         store_offset_new = {}
         for k,out_tokens in self._group_float4(i, store_offset).items():
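
float4 grouping turns four scalar accesses into one vectorized 16-byte load/store, which assumes 4-byte elements; the gate now skips the 1-byte dtypes the same way it already skipped float16. A loose numpy picture of why the grouping doesn't fit:

    import numpy as np

    buf32 = np.arange(8, dtype=np.float32)
    print(buf32.reshape(-1, 4))  # each row models one 16-byte float4 access

    buf8 = np.arange(8, dtype=np.uint8)
    print(buf8.reshape(-1, 4))   # only 4 bytes per row -- not the float4 shape, so 8-bit buffers opt out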

@@ -38,7 +38,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
   module = ir.Module(name=__file__)
 
   # create llvm function
-  func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType()}[buf.dtype] for buf in bufs]
+  func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8)}[buf.dtype] for buf in bufs]
   func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
 
   # force llvmlite to allow us to add function attribute then add the attribute
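
Both new dtypes map to ir.IntType(8) because LLVM integer types are signless; signedness only shows up later in the choice of instructions (e.g. zext vs sext, uitofp vs sitofp). A quick llvmlite check of the types involved:

    import llvmlite.ir as ir

    print(ir.IntType(8))   # i8 -- shared by int8 and uint8
    print(ir.HalfType())   # half
    print(ir.FloatType())  # float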

@@ -59,8 +59,10 @@ class LazyNumpyArray:
 class dtypes:
   float16: Final[DType] = DType(0, 2, "half", np.float16)
   float32: Final[DType] = DType(1, 4, "float", np.float32)
+  int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "int64", np.int64)
+  uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
   @staticmethod
   def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
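
Reading the DType tuples: the fields appear to be (priority for upcasting, size in bytes, name used in generated kernels, matching numpy type), so both 8-bit types get priority 0 and itemsize 1, consistent with the upcast tests above. Since from_np keys off the numpy dtype name, the new entries are picked up automatically:

    import numpy as np
    from tinygrad.helpers import dtypes

    assert dtypes.from_np(np.uint8) == dtypes.uint8
    assert dtypes.from_np(np.int8) == dtypes.int8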

@@ -36,7 +36,7 @@ class RawBufferMapped(RawBufferCopyIn):
 # this one is simple enough that i moved it out of the runtimes
 class RawMallocBuffer(RawBufferMapped):
-  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)())
+  def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8}[dtype] * size)())
   def _buffer(self): return memoryview(self._buf)
 
 class RawBufferCopyInOut(RawBufferCopyIn):
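
RawMallocBuffer backs the CPU-side runtimes with a plain ctypes array. Note the asymmetry: ctypes has no half type, so float16 is stored as c_int16 and only treated as half inside kernels, while int8/uint8 map directly to c_int8/c_uint8. A sketch of what the constructor builds for a 4-element uint8 buffer:

    import ctypes

    buf = (ctypes.c_uint8 * 4)()   # what RawMallocBuffer allocates for dtypes.uint8
    mv = memoryview(buf)           # what _buffer() returns
    mv[0] = 255
    print(list(buf))               # [255, 0, 0, 0]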

@@ -5,7 +5,7 @@ from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
 class ClangProgram:
   def __init__(self, name:str, prg:str):
-    prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n" + prg
+    prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n#define uchar unsigned char\n" + prg
     # TODO: is there a way to not write this to disk?
     fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if platform.system() == 'Darwin' else 'so'}"
     # NOTE: --rtlib=compiler-rt fixes float16 on Linux, it defines __gnu_h2f_ieee and __gnu_f2h_ieee
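
The generated kernels use OpenCL-style type names; char is already a C type, but uchar is not, so the Clang runtime's prologue #defines it the same way it already defines half (this is presumably the clang fix from the commit message). What the prepended source looks like, with a hypothetical kernel body for illustration:

    prologue = (
      "#include <math.h>\n"
      "#define max(x,y) ((x>y)?x:y)\n"
      "#define half __fp16\n"
      "#define uchar unsigned char\n"
    )
    kernel = "void exec(uchar *data0) { data0[0] = (uchar)250; }"  # illustrative only
    print(prologue + kernel)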

@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 from tinygrad.runtime.lib import RawBuffer
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int32: dtypes.int32, torch.int64: dtypes.int64}
+type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8}
 inverse_type_map = {v:k for k,v in type_map.items()}
 
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
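
type_map translates torch dtypes to tinygrad dtypes when wrapping existing torch tensors, and inverse_type_map goes the other way when allocating them, so both directions pick up the new entries. The mapping round-trips by construction:

    import torch
    from tinygrad.helpers import dtypes

    type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8,
                torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8}
    inverse_type_map = {v: k for k, v in type_map.items()}
    assert inverse_type_map[type_map[torch.uint8]] == torch.uint8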