Files
tinygrad/datasets/preprocess_imagenet.py
wozeparrot 2fd2fb6380 int8/uint8 support (#837)
* feat: int8 support

* feat: uint8 support

* feat: int8 tests

* fix: fix uint8 on clang

* feat: test casting between int8/uint8/float16/float32

* clean: way cleaner dtype tests

* feat: preprocess_imagenet using the correct dtype

* feat: add test for overflow between uint8 and int8
2023-05-28 23:15:06 -07:00

23 lines
690 B
Python

from tinygrad.helpers import dtypes
from tinygrad.tensor import Tensor
from datasets.imagenet import iterate, get_val_files
if __name__ == "__main__":
  # Copy the first `sz` samples yielded by iterate() into two disk-backed
  # tensors: inputs go to /tmp/imagenet_x, labels to /tmp/imagenet_y.
  # Alternative size, taken from the list returned by get_val_files():
  #sz = len(get_val_files())
  sz = 32*100
  X,Y = None, None
  idx = 0  # number of samples written so far
  for x,y in iterate(shuffle=False):
    print(x.shape, y.shape, x.dtype, y.dtype)
    assert x.shape[0] == y.shape[0]
    bs = x.shape[0]
    if X is None:
      # Allocate lazily: the per-sample shapes are only known once the first
      # batch has been seen.
      X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
      Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
      print(X.shape, Y.shape)
    # Clamp the final batch so it never runs past the `sz` preallocated rows
    # (happens whenever the batch size does not divide `sz`).
    take = min(bs, sz - idx)
    if take < bs:
      # assumes x and y support leading-axis slicing (e.g. arrays) -- confirm against iterate()
      x, y = x[:take], y[:take]
    X[idx:idx+take].assign(x)
    Y[idx:idx+take].assign(y)
    idx += take
    if idx >= sz: break