diff --git a/extra/datasets/__init__.py b/extra/datasets/__init__.py
index 742a8ac9b3..7636e2086d 100644
--- a/extra/datasets/__init__.py
+++ b/extra/datasets/__init__.py
@@ -7,9 +7,9 @@ def fetch_mnist(tensors=False):
   parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
   BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
   X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
-  Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
+  Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:].astype(np.int8)
   X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
-  Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
+  Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:].astype(np.int8)
   if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
   else: return X_train, Y_train, X_test, Y_test
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index 672bf9c137..a6494e3771 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import Dict, Callable
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
 from tinygrad.device import Interpreted, Allocator
@@ -14,7 +13,6 @@ type_map = {torch.bool: dtypes.bool,
 inverse_type_map = {v: k for k,v in type_map.items()}
 # TODO: should unsupported types fail instead of implicit conversion?
 inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})
-def np_type_cvt(t): return {np.uint32: np.int32, np.uint64: np.int64}.get(t, t)

 def as_strided(x, arg):
   shape, stride, offset = arg
@@ -26,9 +24,7 @@ def as_strided(x, arg):
   return torch.as_strided(x, shape, stride, offset)

 torch_fxn_for_op: Dict[Op, Callable] = {
-  # TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
-  #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
-  BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=np_type_cvt(dtype.np))).to(device),
+  BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
   UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
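
For context (not part of the patch): below is a small, hedged Python sketch of the two BufferOps.CONST construction paths the second file's last hunk swaps, and of the uint8 "overflow" that the removed TODO comment referred to. The value -1, the uint8 dtype, and the printed outcomes are illustrative assumptions, not taken from the diff; actual behaviour depends on the installed torch and numpy versions.

# Hedged sketch, not part of the patch: compare the old numpy-roundtrip CONST path
# with the direct torch.tensor path the diff switches to. The value -1 and the
# uint8 dtype are assumptions chosen only to exercise the out-of-range case.
import numpy as np
import torch

val = -1  # e.g. a negative scalar constant materialized for a uint8 buffer

# New path: build the constant directly with torch.tensor.
try:
  print("torch.tensor:", torch.tensor(val, dtype=torch.uint8))
except RuntimeError as e:
  # Some torch versions reject out-of-range Python scalars for unsigned dtypes;
  # presumably the "overflow" the removed TODO comment mentions.
  print("torch.tensor raised:", e)

# Old path: detour through numpy, then hand the buffer to torch.from_numpy.
try:
  print("numpy detour:", torch.from_numpy(np.array(val, dtype=np.uint8)))
except OverflowError as e:
  # NumPy 2.x also rejects out-of-range Python ints; 1.x wrapped them (here to 255).
  print("numpy detour raised:", e)

If the two paths disagree (one wraps, the other raises), that difference is presumably why the numpy detour and the np_type_cvt helper existed in the first place; dropping them relies on torch.tensor accepting every dtype reachable through inverse_type_map.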