From 67e8df4969f99e9eefce2efa86afdf7829220bcd Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 14 Jun 2024 15:38:45 -0400 Subject: [PATCH] remove numpy from dtype (#4969) replaced all dtype.np with _to_np_dtype defined in tensor.py. after this, the only numpy usages are (1) Tensor(np.ndarray), (2) construct .numpy() output, (3) numpy random buffer --- examples/openpilot/compile2.py | 3 ++- extra/models/mask_rcnn.py | 3 ++- extra/onnx.py | 4 +++- test/external/external_test_opt.py | 13 +++++++------ test/external/fuzz_linearizer.py | 13 +++++++------ test/external/fuzz_schedule.py | 6 +++--- test/external/fuzz_uops.py | 5 +++-- test/external/speed_compare_cuda_nv.py | 5 +++-- test/helpers.py | 7 ++++--- test/test_dtype.py | 10 ++++++---- test/test_dtype_alu.py | 11 ++++++----- test/test_linearizer.py | 20 ++++++++++---------- test/test_ops.py | 3 ++- test/test_tensor.py | 2 +- test/test_uops.py | 10 +++++----- tinygrad/dtype.py | 6 ------ tinygrad/tensor.py | 15 +++++++++------ 17 files changed, 73 insertions(+), 63 deletions(-) diff --git a/examples/openpilot/compile2.py b/examples/openpilot/compile2.py index bfa9bac0a7..56a5f07065 100644 --- a/examples/openpilot/compile2.py +++ b/examples/openpilot/compile2.py @@ -20,6 +20,7 @@ from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner from tinygrad.engine.schedule import ScheduleItem, create_schedule, memory_planner from tinygrad.ops import LoadOps +from tinygrad.tensor import _to_np_dtype Device.DEFAULT = "GPU" def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: @@ -93,7 +94,7 @@ def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tenso output = eis[-1].bufs[0] for ei in eis: ei.run() - new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=output.dtype.np) + new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype)) np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2) print("semi-thneed self-test passed!") diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index 0ed0d3077c..125e3db585 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -4,6 +4,7 @@ import os import numpy as np from pathlib import Path from tinygrad import nn, Tensor, dtypes +from tinygrad.tensor import _to_np_dtype from tinygrad.helpers import get_child, fetch from tinygrad.nn.state import torch_load from extra.models.resnet import ResNet @@ -1217,7 +1218,7 @@ def to_image_list(tensors, size_divisible=32): max_size = tuple(max_size) batch_shape = (len(tensors),) + max_size - batched_imgs = np.zeros(batch_shape, dtype=tensors[0].dtype.np) + batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype)) for img, pad_img in zip(tensors, batched_imgs): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy() diff --git a/extra/onnx.py b/extra/onnx.py index de7bf39394..4d0a587f88 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -4,6 +4,7 @@ import importlib from functools import lru_cache import numpy as np from tinygrad import Tensor, dtypes, Device +from tinygrad.tensor import _to_np_dtype from tinygrad.helpers import getenv, DEBUG, CI, OSX from tinygrad.dtype import ConstType from typing import List, Dict, Union @@ -76,7 +77,8 @@ def get_run_onnx(onnx_model: ModelProto): if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data): return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims)) if len(inp.raw_data) > 0: - return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims)) + return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy(), + requires_grad=False).reshape(tuple(inp.dims)) return Tensor(None, requires_grad=False) def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]: diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 0f895a6168..c81b100bef 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -7,6 +7,7 @@ from tinygrad import GlobalCounters, Tensor, Device from tinygrad.helpers import getenv from tinygrad.nn.state import get_parameters from tinygrad.engine.realize import capturing +from tinygrad.tensor import _to_np_dtype PUSH_PERMUTES = False @@ -46,21 +47,21 @@ class TestInferenceMinKernels(unittest.TestCase): @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") def test_convnext(self): model = ConvNeXt() - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) img = Tensor.randn(1, 3, 224, 224) with CLCache(129): model(img).realize() def test_enet(self): model = EfficientNet(getenv("ENET_NUM", 0), has_se=False) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) img = Tensor.randn(1, 3, 224, 224) with CLCache(51): model.forward(img).realize() def test_enet_se(self): model = EfficientNet(getenv("ENET_NUM", 0), has_se=True) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) img = Tensor.randn(1, 3, 224, 224) # TODO: this seems very high with CLCache(115): @@ -68,14 +69,14 @@ class TestInferenceMinKernels(unittest.TestCase): def test_resnet(self): model = ResNet18() - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) img = Tensor.randn(1, 3, 224, 224) with CLCache(23): model.forward(img).realize() def test_vit(self): model = ViT(embed_dim=192, num_heads=3) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) img = Tensor.randn(1, 3, 224, 224) with CLCache(209) as cache: # NOTE: this is way too high out = model.forward(img) @@ -87,7 +88,7 @@ class TestInferenceMinKernels(unittest.TestCase): from examples.llama import Transformer args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000} model = Transformer(**args_tiny) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=_to_np_dtype(p.dtype))) inp = Tensor([[1,2,3,4]]) with CLCache(100): model(inp, 0).realize() diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 9e186310c2..7e78e4882c 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -5,6 +5,7 @@ from collections import defaultdict from extra.optimization.helpers import load_worlds, ast_str_to_lin from tinygrad import Tensor, Device, dtypes +from tinygrad.tensor import _to_np_dtype from tinygrad.codegen.linearizer import Linearizer, UOp from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.engine.search import get_linearizer_actions, bufs_from_lin @@ -30,15 +31,15 @@ def get_fuzz_rawbufs(lin): with Context(DEBUG=0): for rawbuf in rawbufs[1:]: if dtypes.is_unsigned(rawbuf.dtype): - data = np.random.randint(0, 100, size=rawbuf.size, dtype=rawbuf.dtype.np) + data = np.random.randint(0, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype)) elif dtypes.is_int(rawbuf.dtype): - data = np.random.randint(-100, 100, size=rawbuf.size, dtype=rawbuf.dtype.np) + data = np.random.randint(-100, 100, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype)) elif rawbuf.dtype == dtypes.bool: data = np.random.choice([True, False], size=rawbuf.size) elif rawbuf.dtype == dtypes.half: - data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=rawbuf.dtype.np) + data = np.random.uniform(-1, 1, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) else: - data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=rawbuf.dtype.np) + data = np.random.uniform(-10, 10, size=rawbuf.size).astype(dtype=_to_np_dtype(rawbuf.dtype)) rawbuf.copyin(Tensor(data).realize().lazydata.realized.as_buffer()) return rawbufs @@ -92,7 +93,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut unoptimized.required_optimizations() if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS": return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,) - ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy() + ground_truth = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)).copy() rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS": @@ -100,7 +101,7 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut try: if not has_bf16: - result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np) + result = np.frombuffer(rawbufs[0].as_buffer(), _to_np_dtype(rawbufs[0].dtype)) np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol) except AssertionError as e: if DEBUG >= 2: diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index 9a712e9e96..6cc2e87ab6 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -7,7 +7,7 @@ from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv from tinygrad.lazy import LazyBuffer from tinygrad.engine.schedule import _graph_schedule, _LBScheduleItem, ScheduleItem from tinygrad.ops import LoadOps -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, _to_np_dtype ctx_vars = { MULTIOUTPUT: (0, 1) } @@ -60,8 +60,8 @@ def fuzz_schedule(outs:List[LazyBuffer]): si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0)) _exec_si(si, seed) for out in ps.outputs: - outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np) - try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], out.dtype.np), atol=1e-2, rtol=1e-2) + outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype)) + try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2) except Exception as e: print(f"FAILED FOR {out}") raise e diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py index 05f0462d82..64e32271d7 100644 --- a/test/external/fuzz_uops.py +++ b/test/external/fuzz_uops.py @@ -6,6 +6,7 @@ from tinygrad.device import Buffer, Device from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import DEBUG, colored, getenv from tinygrad.shape.symbolic import Variable +from tinygrad.tensor import _to_np_dtype def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]): paths: List[List[UOp]] = [] @@ -27,7 +28,7 @@ class UOpsFuzzerRunner(CompiledRunner): if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow")) super().__call__(rawbufs, var_vals, wait) - ground_truth = {x:np.frombuffer(x.as_buffer(), x.dtype.np) for x in rawbufs} + ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs} for i, path in enumerate(self.p.uops.fuzz_paths): # setup prg @@ -43,7 +44,7 @@ class UOpsFuzzerRunner(CompiledRunner): super().__call__(rawbufs, var_vals, wait) for i, x in enumerate(rawbufs): try: - np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[x], atol=1e-6, rtol=1e-6) + np.testing.assert_allclose(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)), ground_truth[x], atol=1e-6, rtol=1e-6) if DEBUG >= 2: print(colored(name, "green")) except AssertionError as e: print(colored(name, "red")) diff --git a/test/external/speed_compare_cuda_nv.py b/test/external/speed_compare_cuda_nv.py index 102076f700..10c60e00d3 100644 --- a/test/external/speed_compare_cuda_nv.py +++ b/test/external/speed_compare_cuda_nv.py @@ -4,6 +4,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin from test.external.fuzz_linearizer import get_fuzz_rawbufs from tinygrad.engine.search import bufs_from_lin from tinygrad.engine.realize import CompiledRunner +from tinygrad.tensor import _to_np_dtype import numpy as np # move to helpers? @@ -63,8 +64,8 @@ if __name__ == "__main__": failed = True if not failed and not has_bf16: - curesult = np.frombuffer(test_cubufs[0].as_buffer(), test_cubufs[0].dtype.np) - nvresult = np.frombuffer(test_nvbufs[0].as_buffer(), test_nvbufs[0].dtype.np) + curesult = np.frombuffer(test_cubufs[0].as_buffer(), _to_np_dtype(test_cubufs[0].dtype)) + nvresult = np.frombuffer(test_nvbufs[0].as_buffer(), _to_np_dtype(test_nvbufs[0].dtype)) np.testing.assert_allclose(curesult, nvresult, rtol=1e-2, atol=1e-2) average_tm_cuda += min(tm_cuda) diff --git a/test/helpers.py b/test/helpers.py index e2e395b1bc..dfd57cee73 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,6 +1,7 @@ import sys import numpy as np from tinygrad import Tensor, Device, dtypes +from tinygrad.tensor import _to_np_dtype from tinygrad.engine.realize import Runner from tinygrad.dtype import DType from tinygrad.nn.state import get_parameters @@ -41,9 +42,9 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): def rand_for_dtype(dt:DType, size:int): if dtypes.is_unsigned(dt): - return np.random.randint(0, 100, size=size, dtype=dt.np) + return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt)) elif dtypes.is_int(dt): - return np.random.randint(-100, 100, size=size, dtype=dt.np) + return np.random.randint(-100, 100, size=size, dtype=_to_np_dtype(dt)) elif dt == dtypes.bool: return np.random.choice([True, False], size=size) - return np.random.uniform(-10, 10, size=size).astype(dt.np) + return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) diff --git a/test/test_dtype.py b/test/test_dtype.py index 2988805f61..7dff745bce 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -5,6 +5,7 @@ from typing import Any, List from tinygrad.helpers import getenv, DEBUG, CI from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype from tinygrad import Device, Tensor, dtypes +from tinygrad.tensor import _to_np_dtype from hypothesis import given, settings, strategies as strat from test.helpers import is_dtype_supported, rand_for_dtype @@ -51,10 +52,10 @@ def _test_cast(a:Tensor, target_dtype:DType): # TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084 return - _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np))) + _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype)))) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet") - _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist()) + _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist()) class TestDType(unittest.TestCase): DTYPE: Any = None @@ -66,7 +67,8 @@ class TestDType(unittest.TestCase): def setUp(self): if self.DTYPE is None: raise unittest.SkipTest("base class") - def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np)) + def test_to_np(self): + _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE))) def test_casts_to(self): list(map( lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE), @@ -104,7 +106,7 @@ class TestDType(unittest.TestCase): def test_dtypes_fields(self): fields = dtypes.fields() self.assertTrue(all(isinstance(value, DType) for value in fields.values())) - self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None)) + self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None)) def test_resulting_and_init_dtypes_match(self): dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"])) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 802f40281a..38fab57d7f 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -9,6 +9,7 @@ from tinygrad.helpers import CI, getenv from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule from tinygrad.ops import UnaryOps +from tinygrad.tensor import _to_np_dtype from test.helpers import is_dtype_supported settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) @@ -59,7 +60,7 @@ class ht: def universal_test(a, b, dtype, op): if not isinstance(op, tuple): op = (op, op) tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy() - numpy_value = op[1](np.array([a]).astype(dtype.np), np.array([b]).astype(dtype.np)) + numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype)), np.array([b]).astype(_to_np_dtype(dtype))) if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10) else: np.testing.assert_equal(tensor_value, numpy_value) @@ -70,7 +71,7 @@ def universal_test_unary(a, dtype, op): ast = sched[-1].ast[0] run_schedule(sched) tensor_value = out.numpy() - numpy_value = op[1](np.array([a]).astype(dtype.np)) + numpy_value = op[1](np.array([a]).astype(_to_np_dtype(dtype))) if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) else: np.testing.assert_equal(tensor_value, numpy_value) @@ -80,16 +81,16 @@ def universal_test_unary(a, dtype, op): def universal_test_cast(a, in_dtype, dtype): tensor_value = Tensor([a], dtype=in_dtype).cast(dtype) - numpy_value = np.array([a]).astype(dtype.np) + numpy_value = np.array([a]).astype(_to_np_dtype(dtype)) np.testing.assert_equal(tensor_value.numpy(), numpy_value) def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType): if not isinstance(op1, tuple): op1 = (op1, op1) if not isinstance(op2, tuple): op2 = (op2, op2) at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2) - an, bn, cn = np.array([a]).astype(d1.np), np.array([b]).astype(d1.np), np.array([c]).astype(d2.np) + an, bn, cn = np.array([a]).astype(_to_np_dtype(d1)), np.array([b]).astype(_to_np_dtype(d1)), np.array([c]).astype(_to_np_dtype(d2)) tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy() - numpy_value = op2[1](op1[1](an, bn).astype(d2.np), cn) + numpy_value = op2[1](op1[1](an, bn).astype(_to_np_dtype(d2)), cn) np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7) class TestDTypeALU(unittest.TestCase): diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7cf9093765..8e8ad34299 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -12,7 +12,7 @@ from tinygrad.renderer import TensorCore from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner from tinygrad.engine.graph import print_tree @@ -533,12 +533,12 @@ class TestLinearizer(unittest.TestCase): assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" prg = CompiledRunner(k.to_program()) - real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled + real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) - result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) + result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) # ensure the results for each choice of axis matches - if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np) + if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15) # check that get_linearizer_actions produces all 9 options @@ -999,31 +999,31 @@ def _helper_linearizer_opt_ast(realized_ast:Tuple[LazyOp, ...], real_bufs:List[B if expected_color_size is not None: assert (cs:=[(x,y) for x,y in zip(k.colors(), k.full_shape)]) == expected_color_size, f"expected={expected_color_size} got={cs}" prg = get_prg(k) - for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=buf.dtype.np).data) # Zero to check that all values are filled + for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) for i, buf in enumerate(outbufs): - np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol) + np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol) # Get baseline if it is not provided, which is not optimized at all. k = Linearizer(*realized_ast) lins.append(k) prg = get_prg(k) prg.exec(real_bufs) - if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), buf.dtype.np).copy() for buf in outbufs] + if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).copy() for buf in outbufs] else: for i, buf in enumerate(outbufs): - np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol) + np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol) # Check correctness of handcoded optimiztions. k = Linearizer(*realized_ast) lins.append(k) k.hand_coded_optimizations() prg = get_prg(k) - for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=buf.dtype.np).data) # Zero to check that all values are filled + for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) for i, buf in enumerate(outbufs): - np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol) + np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol) for i, x in enumerate(opts): # Check custom transformations if any. check_opt(x, lambda: Linearizer(*realized_ast), color_sizes[i] if i < len(color_sizes) else None) return lins diff --git a/test/test_ops.py b/test/test_ops.py index e5b2b72c0b..5a437ab578 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,6 +3,7 @@ import numpy as np import torch from tinygrad.helpers import getenv, IMAGE, DEBUG, CI from tinygrad import Tensor, Device, dtypes +from tinygrad.tensor import _to_np_dtype if CI: import warnings @@ -66,7 +67,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False): ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals] else: np.random.seed(0) - np_data = [np.random.uniform(low=low, high=high, size=size).astype(dtypes.default_float.np) for size in shps] + np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps] ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data] tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts] return ts, tst diff --git a/test/test_tensor.py b/test/test_tensor.py index fa48c52092..870acc3585 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -320,7 +320,7 @@ class TestTinygrad(unittest.TestCase): # Regression test for https://github.com/tinygrad/tinygrad/issues/1751 def test_copy_from_numpy_unaligned(self): # 2**15 is the minimum for repro - arr = np.random.randn(2**15).astype(dtypes.float.np) + arr = np.random.randn(2**15).astype(np.float32) fn = temp('test_copy_from_numpy_unaligned') with open(fn, 'wb') as f: f.write(b't' + arr.tobytes()) with open(fn, "a+b") as f: memview = memoryview(mmap.mmap(f.fileno(), arr.nbytes + 1)) diff --git a/test/test_uops.py b/test/test_uops.py index ee29dd68c5..065d9a1ff7 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple, Any, List import unittest, math import numpy as np -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.device import Buffer, Device @@ -33,10 +33,10 @@ def _test_single_value(vals, op, dts): alu = uop(uops, UOps.ALU, output_dtype, loads, op) out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() - buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)] + buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)] prg = _uops_to_prg([out]) prg.exec([buf]+buf2) - ret = np.empty(1, output_dtype.np) + ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] @@ -50,7 +50,7 @@ def _test_single_value_const(vals, op, dts): buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) prg.exec([buf]) - ret = np.empty(1, output_dtype.np) + ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] @@ -62,7 +62,7 @@ def _test_uops_result(output_dtype, uops, res): buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out], print=True) prg.exec([buf]) - ret = np.empty(1, output_dtype.np) + ret = np.empty(1, _to_np_dtype(output_dtype)) buf.copyout(ret.data) return ret[0] diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 8287cb9278..9c1481f282 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -1,6 +1,5 @@ from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union from dataclasses import dataclass -import numpy as np # TODO: remove numpy import functools from tinygrad.helpers import getenv @@ -18,9 +17,6 @@ class DType: assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}" return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self - # TODO: someday this will be removed with the "remove numpy" project - @property - def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None # dependent typing? @dataclass(frozen=True, repr=False) @@ -47,8 +43,6 @@ class dtypes: @staticmethod def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @staticmethod - def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name] - @staticmethod def from_py(x) -> DType: if x.__class__ is float: return dtypes.default_float if x.__class__ is int: return dtypes.default_int diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5177a9f8ed..3db66be90c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -43,6 +43,9 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src) return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None) +def _from_np_dtype(npdtype:type) -> DType: return dtypes.fields()[np.dtype(npdtype).name] +def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None + def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer: if isinstance(x, bytes): ret, data = LazyBuffer.loadop(LoadOps.EMPTY, (len(x),), dtype, "PYTHON"), x else: @@ -118,15 +121,15 @@ class Tensor: else: data = _frompy(data, dtype) elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device) elif isinstance(data, np.ndarray): - if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) + if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) else: def _fromnp(x: np.ndarray) -> LazyBuffer: - ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "NPY") + ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") # fake realize ret.buffer.allocate(x) del ret.srcs return ret - data = _fromnp(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) + data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # by this point, it has to be a LazyBuffer if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): @@ -285,9 +288,9 @@ class Tensor: ``` """ if self.dtype == dtypes.bfloat16: return self.float().numpy() - assert self.dtype.np is not None, f"no np dtype for {self.dtype}" + assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" - return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape) + return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape) def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor: """ @@ -2909,5 +2912,5 @@ def custom_random(out:Buffer): Tensor._seed += 1 rng = np.random.default_rng(Tensor._seed) if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False) - else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False) + else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False) out.copyin(rng_np_buffer.data)