remove numpy from dtype (#4969)

replaced every use of dtype.np with the _to_np_dtype helper now defined in tensor.py (dtypes.from_np likewise becomes _from_np_dtype).

after this, the only remaining numpy usages in core tinygrad are (1) constructing a Tensor from an np.ndarray, (2) building the .numpy() output, and (3) the numpy-backed random buffer. a short sketch of the new helpers follows below.
chenyu
2024-06-14 15:38:45 -04:00
committed by GitHub
parent 62dc36d371
commit 67e8df4969
17 changed files with 73 additions and 63 deletions
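
For reference, the two helpers this commit adds to tinygrad/tensor.py are one-liners (they appear verbatim in the tensor.py hunk at the bottom). A minimal, self-contained sketch of them and of the call-site migration, with import paths taken from the hunks below:

```python
from typing import Optional
import numpy as np
from tinygrad import dtypes
from tinygrad.dtype import DType

# the two helpers added in tinygrad/tensor.py by this commit
def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None

# call-site migration repeated across this diff:
#   before: np.zeros(p.shape, dtype=p.dtype.np)
#   after:  np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))
# dtypes with no numpy equivalent (e.g. bfloat16) map to None, same as the old property did
```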

View File

@@ -20,6 +20,7 @@ from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.schedule import ScheduleItem, create_schedule, memory_planner
from tinygrad.ops import LoadOps
+from tinygrad.tensor import _to_np_dtype
Device.DEFAULT = "GPU"
def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
@@ -93,7 +94,7 @@ def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tenso
output = eis[-1].bufs[0]
for ei in eis: ei.run()
-new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=output.dtype.np)
+new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!")

View File

@@ -4,6 +4,7 @@ import os
import numpy as np
from pathlib import Path
from tinygrad import nn, Tensor, dtypes
+from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
@@ -1217,7 +1218,7 @@ def to_image_list(tensors, size_divisible=32):
max_size = tuple(max_size)
batch_shape = (len(tensors),) + max_size
-batched_imgs = np.zeros(batch_shape, dtype=tensors[0].dtype.np)
+batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype))
for img, pad_img in zip(tensors, batched_imgs):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()

View File

@@ -4,6 +4,7 @@ import importlib
from functools import lru_cache
import numpy as np
from tinygrad import Tensor, dtypes, Device
+from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import getenv, DEBUG, CI, OSX
from tinygrad.dtype import ConstType
from typing import List, Dict, Union
@@ -76,7 +77,8 @@ def get_run_onnx(onnx_model: ModelProto):
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
if len(inp.raw_data) > 0:
-return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims))
+return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy(),
+requires_grad=False).reshape(tuple(inp.dims))
return Tensor(None, requires_grad=False)
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:

View File

@@ -7,6 +7,7 @@ from tinygrad import GlobalCounters, Tensor, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import capturing
+from tinygrad.tensor import _to_np_dtype
PUSH_PERMUTES = False
@@ -46,21 +47,21 @@ class TestInferenceMinKernels(unittest.TestCase):
@unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES")
def test_convnext(self):
model = ConvNeXt()
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(129):
model(img).realize()
def test_enet(self):
model = EfficientNet(getenv("ENET_NUM", 0), has_se=False)
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(51):
model.forward(img).realize()
def test_enet_se(self):
model = EfficientNet(getenv("ENET_NUM", 0), has_se=True)
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
img = Tensor.randn(1, 3, 224, 224)
# TODO: this seems very high
with CLCache(115):
@@ -68,14 +69,14 @@ class TestInferenceMinKernels(unittest.TestCase):
def test_resnet(self):
model = ResNet18()
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(23):
model.forward(img).realize()
def test_vit(self):
model = ViT(embed_dim=192, num_heads=3)
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(209) as cache: # NOTE: this is way too high
out = model.forward(img)
@@ -87,7 +88,7 @@ class TestInferenceMinKernels(unittest.TestCase):
from examples.llama import Transformer
args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
-for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype)))
inp = Tensor([[1,2,3,4]])
with CLCache(100):
model(inp, 0).realize()

View File

@@ -5,6 +5,7 @@ from collections import defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad import Tensor, Device, dtypes
+from tinygrad.tensor import _to_np_dtype
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_linearizer_actions, bufs_from_lin
@@ -30,15 +31,15 @@ def get_fuzz_rawbufs(lin):
with Context(DEBUG=0):
for rawbuf in rawbufs[1:]:
if dtypes.is_unsigned(rawbuf.dtype):
-data = np.random.randint(0, 100, size=rawbuf.size, dtype=rawbuf.dtype.np)
+data = np.random.randint(0, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
elif dtypes.is_int(rawbuf.dtype):
-data = np.random.randint(-100, 100, size=rawbuf.size, dtype=rawbuf.dtype.np)
+data = np.random.randint(-100, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
elif rawbuf.dtype == dtypes.bool:
data = np.random.choice([True, False], size=rawbuf.size)
elif rawbuf.dtype == dtypes.half:
-data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=rawbuf.dtype.np)
+data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
else:
-data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=rawbuf.dtype.np)
+data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer())
return rawbufs
@@ -92,7 +93,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
-ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
+ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy()
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
@@ -100,7 +101,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
try:
if not has_bf16:
-result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
+result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype))
np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
except AssertionError as e:
if DEBUG >= 2:

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem, ScheduleItem
from tinygrad.ops import LoadOps
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
@@ -60,8 +60,8 @@ def fuzz_schedule(outs:List[LazyBuffer]):
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
-outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np)
-try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], out.dtype.np), atol=1e-2, rtol=1e-2)
+outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
+try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
except Exception as e:
print(f"FAILED FOR {out}")
raise e

View File

@@ -6,6 +6,7 @@ from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.shape.symbolic import Variable
+from tinygrad.tensor import _to_np_dtype
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
paths: List[List[UOp]] = []
@@ -27,7 +28,7 @@ class UOpsFuzzerRunner(CompiledRunner):
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow"))
super().__call__(rawbufs, var_vals, wait)
-ground_truth = {x:np.frombuffer(x.as_buffer(), x.dtype.np) for x in rawbufs}
+ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}
for i, path in enumerate(self.p.uops.fuzz_paths):
# setup prg
@@ -43,7 +44,7 @@ class UOpsFuzzerRunner(CompiledRunner):
super().__call__(rawbufs, var_vals, wait)
for i, x in enumerate(rawbufs):
try:
-np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[x], atol=1e-6, rtol=1e-6)
+np.testing.assert_allclose(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)), ground_truth[x], atol=1e-6, rtol=1e-6)
if DEBUG >= 2: print(colored(name, "green"))
except AssertionError as e:
print(colored(name, "red"))

View File

@@ -4,6 +4,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin
from test.external.fuzz_linearizer import get_fuzz_rawbufs
from tinygrad.engine.search import bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
+from tinygrad.tensor import _to_np_dtype
import numpy as np
# move to helpers?
@@ -63,8 +64,8 @@ if __name__ == "__main__":
failed = True
if not failed and not has_bf16:
-curesult = np.frombuffer(test_cubufs[0].as_buffer(), test_cubufs[0].dtype.np)
-nvresult = np.frombuffer(test_nvbufs[0].as_buffer(), test_nvbufs[0].dtype.np)
+curesult = np.frombuffer(test_cubufs[0].as_buffer(), _to_np_dtype(test_cubufs[0].dtype))
+nvresult = np.frombuffer(test_nvbufs[0].as_buffer(), _to_np_dtype(test_nvbufs[0].dtype))
np.testing.assert_allclose(curesult, nvresult, rtol=1e-2, atol=1e-2)
average_tm_cuda += min(tm_cuda)

View File

@@ -1,6 +1,7 @@
import sys
import numpy as np
from tinygrad import Tensor, Device, dtypes
+from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import DType
from tinygrad.nn.state import get_parameters
@@ -41,9 +42,9 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
def rand_for_dtype(dt:DType, size:int):
if dtypes.is_unsigned(dt):
-return np.random.randint(0, 100, size=size, dtype=dt.np)
+return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt))
elif dtypes.is_int(dt):
-return np.random.randint(-100, 100, size=size, dtype=dt.np)
+return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt))
elif dt == dtypes.bool:
return np.random.choice([True, False], size=size)
-return np.random.uniform(-10, 10, size=size).astype(dt.np)
+return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))

View File

@@ -5,6 +5,7 @@ from typing import Any, List
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device, Tensor, dtypes
+from tinygrad.tensor import _to_np_dtype
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported, rand_for_dtype
@@ -51,10 +52,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
# TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084
return
-_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
+_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
-_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
+_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
class TestDType(unittest.TestCase):
DTYPE: Any = None
@@ -66,7 +67,8 @@ class TestDType(unittest.TestCase):
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")
-def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
+def test_to_np(self):
+_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
@@ -104,7 +106,7 @@ class TestDType(unittest.TestCase):
def test_dtypes_fields(self):
fields = dtypes.fields()
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
-self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
+self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
def test_resulting_and_init_dtypes_match(self):
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))

View File

@@ -9,6 +9,7 @@ from tinygrad.helpers import CI, getenv
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.ops import UnaryOps
+from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -59,7 +60,7 @@ class ht:
def universal_test(a, b, dtype, op):
if not isinstance(op, tuple): op = (op, op)
tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
-numpy_value = op[1](np.array([a]).astype(dtype.np), np.array([b]).astype(dtype.np))
+numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype)))
if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
else: np.testing.assert_equal(tensor_value, numpy_value)
@@ -70,7 +71,7 @@ def universal_test_unary(a, dtype, op):
ast = sched[-1].ast[0]
run_schedule(sched)
tensor_value = out.numpy()
-numpy_value = op[1](np.array([a]).astype(dtype.np))
+numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)))
if dtype in dtypes_float:
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
@@ -80,16 +81,16 @@ def universal_test_unary(a, dtype, op):
def universal_test_cast(a, in_dtype, dtype):
tensor_value = Tensor([a], dtype=in_dtype).cast(dtype)
-numpy_value = np.array([a]).astype(dtype.np)
+numpy_value = np.array([a]).astype(_to_np_dtype(dtype))
np.testing.assert_equal(tensor_value.numpy(), numpy_value)
def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
if not isinstance(op1, tuple): op1 = (op1, op1)
if not isinstance(op2, tuple): op2 = (op2, op2)
at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2)
-an, bn, cn = np.array([a]).astype(d1.np), np.array([b]).astype(d1.np), np.array([c]).astype(d2.np)
+an, bn, cn = np.array([a]).astype(_to_np_dtype(d1)), np.array([b]).astype(_to_np_dtype(d1)), np.array([c]).astype(_to_np_dtype(d2))
tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy()
-numpy_value = op2[1](op1[1](an, bn).astype(d2.np), cn)
+numpy_value = op2[1](op1[1](an, bn).astype(_to_np_dtype(d2)), cn)
np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)
class TestDTypeALU(unittest.TestCase):

View File

@@ -12,7 +12,7 @@ from tinygrad.renderer import TensorCore
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.engine.graph import print_tree
@@ -533,12 +533,12 @@ class TestLinearizer(unittest.TestCase):
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg = CompiledRunner(k.to_program())
-real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
+real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled
prg.exec(real_bufs)
-result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
# ensure the results for each choice of axis matches
-if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
# check that get_linearizer_actions produces all 9 options
@@ -999,31 +999,31 @@ def _helper_linearizer_opt_ast(realized_ast:Tuple[LazyOp, ...], real_bufs:List[B
if expected_color_size is not None:
assert (cs:=[(x,y) for x,y in zip(k.colors(), k.full_shape)]) == expected_color_size, f"expected={expected_color_size} got={cs}"
prg = get_prg(k)
-for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=buf.dtype.np).data) # Zero to check that all values are filled
+for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
prg.exec(real_bufs)
for i, buf in enumerate(outbufs):
-np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)
+np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Get baseline if it is not provided, which is not optimized at all.
k = Linearizer(*realized_ast)
lins.append(k)
prg = get_prg(k)
prg.exec(real_bufs)
-if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), buf.dtype.np).copy() for buf in outbufs]
+if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).copy() for buf in outbufs]
else:
for i, buf in enumerate(outbufs):
-np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)
+np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Check correctness of handcoded optimiztions.
k = Linearizer(*realized_ast)
lins.append(k)
k.hand_coded_optimizations()
prg = get_prg(k)
-for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=buf.dtype.np).data) # Zero to check that all values are filled
+for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
prg.exec(real_bufs)
for i, buf in enumerate(outbufs):
-np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)
+np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
for i, x in enumerate(opts): # Check custom transformations if any.
check_opt(x, lambda: Linearizer(*realized_ast), color_sizes[i] if i < len(color_sizes) else None)
return lins

View File

@@ -3,6 +3,7 @@ import numpy as np
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad import Tensor, Device, dtypes
+from tinygrad.tensor import _to_np_dtype
if CI:
import warnings
@@ -66,7 +67,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
else:
np.random.seed(0)
-np_data = [np.random.uniform(low=low, high=high, size=size).astype(dtypes.default_float.np) for size in shps]
+np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst

View File

@@ -320,7 +320,7 @@ class TestTinygrad(unittest.TestCase):
# Regression test for https://github.com/tinygrad/tinygrad/issues/1751
def test_copy_from_numpy_unaligned(self):
# 2**15 is the minimum for repro
-arr = np.random.randn(2**15).astype(dtypes.float.np)
+arr = np.random.randn(2**15).astype(np.float32)
fn = temp('test_copy_from_numpy_unaligned')
with open(fn, 'wb') as f: f.write(b't' + arr.tobytes())
with open(fn, "a+b") as f: memview = memoryview(mmap.mmap(f.fileno(), arr.nbytes + 1))

View File

@@ -1,7 +1,7 @@
from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
@@ -33,10 +33,10 @@ def _test_single_value(vals, op, dts):
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
-buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
+buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg([out])
prg.exec([buf]+buf2)
-ret = np.empty(1, output_dtype.np)
+ret = np.empty(1, _to_np_dtype(output_dtype))
buf.copyout(ret.data)
return ret[0]
@@ -50,7 +50,7 @@ def _test_single_value_const(vals, op, dts):
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out])
prg.exec([buf])
-ret = np.empty(1, output_dtype.np)
+ret = np.empty(1, _to_np_dtype(output_dtype))
buf.copyout(ret.data)
return ret[0]
@@ -62,7 +62,7 @@ def _test_uops_result(output_dtype, uops, res):
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out], print=True)
prg.exec([buf])
-ret = np.empty(1, output_dtype.np)
+ret = np.empty(1, _to_np_dtype(output_dtype))
buf.copyout(ret.data)
return ret[0]

View File

@@ -1,6 +1,5 @@
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
from dataclasses import dataclass
-import numpy as np # TODO: remove numpy
import functools
from tinygrad.helpers import getenv
@@ -18,9 +17,6 @@ class DType:
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
-# TODO: someday this will be removed with the "remove numpy" project
-@property
-def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None
# dependent typing?
@dataclass(frozen=True, repr=False)
@@ -47,8 +43,6 @@ class dtypes:
@staticmethod
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-@staticmethod
-def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
def from_py(x) -> DType:
if x.__class__ is float: return dtypes.default_float
if x.__class__ is int: return dtypes.default_int
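
With DType.np and dtypes.from_np gone, the numpy mapping now lives only in tensor.py. A small sanity check in the spirit of the updated test_dtypes_fields test; the round-trip assert at the end is an extra check of my own, not part of the diff:

```python
import numpy as np
from tinygrad import dtypes
from tinygrad.tensor import _to_np_dtype, _from_np_dtype

for name, dt in dtypes.fields().items():
    np_type = _to_np_dtype(dt)
    if np_type is None: continue            # e.g. bfloat16 has no numpy equivalent
    assert issubclass(np_type, np.generic)  # same property the updated test checks
    assert _from_np_dtype(np_type) == dt, f"round-trip failed for {name}"  # assumed to hold
```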

View File

@@ -43,6 +43,9 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
+def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
+def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x
else:
@@ -118,15 +121,15 @@ class Tensor:
else: data = _frompy(data, dtype)
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, np.ndarray):
-if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
+if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
else:
def _fromnp(x: np.ndarray) -> LazyBuffer:
-ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY")
+ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
ret.buffer.allocate(x)
del ret.srcs
return ret
-data = _fromnp(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
+data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
# by this point, it has to be a LazyBuffer
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
@@ -285,9 +288,9 @@ class Tensor:
```
"""
if self.dtype == dtypes.bfloat16: return self.float().numpy()
-assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
+assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
-return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
+return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
"""
@@ -2909,5 +2912,5 @@ def custom_random(out:Buffer):
Tensor._seed += 1
rng = np.random.default_rng(Tensor._seed)
if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
-else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
out.copyin(rng_np_buffer.data)
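
Putting the tensor.py changes together, the three remaining numpy touch points from the commit message behave roughly like this; a sketch whose asserts restate what the hunks above already do rather than add new behavior:

```python
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _from_np_dtype, _to_np_dtype

# (1) Tensor(np.ndarray): the tinygrad dtype is derived with _from_np_dtype
arr = np.arange(6, dtype=np.int32).reshape(2, 3)
t = Tensor(arr)
assert t.dtype == _from_np_dtype(arr.dtype) == dtypes.int32

# (2) .numpy() output: rebuilt with np.frombuffer(..., dtype=_to_np_dtype(self.dtype))
back = t.numpy()
assert back.dtype == np.dtype(_to_np_dtype(t.dtype))
np.testing.assert_equal(back, arr)

# (3) the numpy random buffer: custom_random (above) generates float32 data and
#     casts it to _to_np_dtype(out.dtype) before copyin, except for the half path
```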