From 0723f26c80b3474dcb288c78df82dd586518a740 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 18 Dec 2023 12:21:44 -0500 Subject: [PATCH] dtypes.default_float and dtypes.default_int (#2824) --- examples/hlb_cifar10.py | 15 +-- extra/onnx_ops.py | 4 +- test/external/external_test_jit_on_models.py | 6 +- test/models/test_real_world.py | 8 +- test/test_dtype.py | 112 +++++++++++-------- test/test_hip_rdna3.py | 9 +- test/test_tensor.py | 4 +- tinygrad/helpers.py | 6 +- tinygrad/tensor.py | 26 ++--- 9 files changed, 108 insertions(+), 82 deletions(-) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index a91319d3ee..a4e6d6eea7 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -12,7 +12,7 @@ if __name__ == "__main__": # https://siboehm.com/articles/22/CUDA-MMM import random, time import numpy as np -from typing import Any, Dict, Optional, SupportsIndex, Type, Union +from typing import Any, Dict, Optional, SupportsIndex from extra.datasets import fetch_cifar, cifar_mean, cifar_std from tinygrad import nn from tinygrad.nn.state import get_state_dict @@ -27,10 +27,10 @@ from tinygrad.jit import TinyJit BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000) if getenv("HALF", 0): - Tensor.default_type = dtypes.float16 - np_dtype: Type[Union[np.float16, np.float32]] = np.float16 + dtypes.default_float = dtypes.float16 + np_dtype = np.float16 else: - Tensor.default_type = dtypes.float32 + dtypes.default_float = dtypes.float32 np_dtype = np.float32 class BatchNorm(nn.BatchNorm2d): @@ -52,13 +52,13 @@ class ConvGroup: x = x.max_pool2d(2) x = x.float() x = self.norm1(x) - x = x.cast(Tensor.default_type) + x = x.cast(dtypes.default_float) x = x.gelu() residual = x x = self.conv2(x) x = x.float() x = self.norm2(x) - x = x.cast(Tensor.default_type) + x = x.cast(dtypes.default_float) x = x.gelu() return x + residual @@ -277,7 +277,8 @@ def train_cifar(): X_train = pad_reflect(X_train, size=hyp['net']['pad_amount']) # Convert data and labels to the default dtype - X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type) + X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float) + X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float) # parse the training params into bias and non-bias params_dict = get_state_dict(model) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index fb8bc38fc0..9b8b0d5c6a 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -482,8 +482,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)] output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)] scales_ = [os/xs for xs, os in zip(X.shape, output_shape)] - x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type) - y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type) + x_out = Tensor.arange(output_shape[-1]).cast(dtypes.default_float) + y_out = Tensor.arange(output_shape[-2]).cast(dtypes.default_float) if mode == "nearest": x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi) x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1]) diff --git a/test/external/external_test_jit_on_models.py b/test/external/external_test_jit_on_models.py index 0eaff63c18..45f6be53fc 100644 --- a/test/external/external_test_jit_on_models.py +++ b/test/external/external_test_jit_on_models.py @@ -15,8 +15,8 @@ def helper_test_jitted_correctness(gen, train, train_jit): class TestJittedModels(unittest.TestCase): def test_jitted_tiny_llama(self): - old_type = Tensor.default_type - Tensor.default_type = dtypes.float16 + old_float = dtypes.default_float + dtypes.default_float = dtypes.float16 args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} model = Transformer(**args_tiny) @@ -26,7 +26,7 @@ class TestJittedModels(unittest.TestCase): @TinyJit def test_jit(t): return model(t, 0).realize() helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit) - Tensor.default_type = old_type + dtypes.default_float = old_float @unittest.skipUnless(not CI, "huge for CI") def test_jitted_stable_diffusion(self): diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 84e6bea300..c4f1f0edd3 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -40,11 +40,11 @@ class TestRealWorld(unittest.TestCase): gc.collect() global global_mem_used global_mem_used = GlobalCounters.mem_used - self.old_type = Tensor.default_type + self.old_float = dtypes.default_float np.random.seed(2002) def tearDown(self): - Tensor.default_type = self.old_type + dtypes.default_float = self.old_float @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipIf(CI, "too big for CI") @@ -68,7 +68,7 @@ class TestRealWorld(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1") def test_llama(self): - Tensor.default_type = dtypes.float16 + dtypes.default_float = dtypes.float16 args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"])) @@ -80,7 +80,7 @@ class TestRealWorld(unittest.TestCase): @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16") def test_gpt2(self): - Tensor.default_type = dtypes.float16 + dtypes.default_float = dtypes.float16 args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000} model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"])) diff --git a/test/test_dtype.py b/test/test_dtype.py index 719999df25..ebe4ecd5ac 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -224,56 +224,78 @@ class TestHelpers(unittest.TestCase): assert dtype.vec(amt).scalar() == dtype class TestTypeSpec(unittest.TestCase): - def test_creation(self): - assert Tensor(True).dtype == dtypes.bool - assert Tensor(2).dtype == dtypes.int - assert Tensor(2.34).dtype == Tensor.default_type - assert Tensor([]).dtype == Tensor.default_type - assert Tensor([1]).dtype == dtypes.int - assert Tensor([1.1]).dtype == Tensor.default_type - assert Tensor([0, 1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16 + def setUp(self): + self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float + def tearDown(self): + dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float - def test_const_full(self): - assert Tensor.ones([2,3]).dtype == Tensor.default_type - assert Tensor.zeros([2,3]).dtype == Tensor.default_type - assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type - assert Tensor.full([2,3], 3).dtype == dtypes.int + def test_set_dtype_default(self): + dtypes.default_int = dtypes.int16 + assert dtypes.default_int == dtypes.int16 + dtypes.default_int = dtypes.int64 + assert dtypes.default_int == dtypes.int64 + dtypes.default_int = dtypes.int32 + assert dtypes.default_int == dtypes.int32 + dtypes.default_float = dtypes.float16 + assert dtypes.default_float == dtypes.float16 + dtypes.default_float = dtypes.float64 + assert dtypes.default_float == dtypes.float64 + + @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_creation(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + assert Tensor(True).dtype == dtypes.bool + assert Tensor(2).dtype == dtypes.default_int + assert Tensor(2.34).dtype == dtypes.default_float + assert Tensor([]).dtype == dtypes.default_float + assert Tensor([1]).dtype == dtypes.default_int + assert Tensor([1.1]).dtype == dtypes.default_float + assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16 + + assert Tensor.eye(0).dtype == dtypes.default_float + assert Tensor.eye(3).dtype == dtypes.default_float + assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64 + + + @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_full(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + + assert Tensor.ones([2,3]).dtype == dtypes.default_float + assert Tensor.zeros([2,3]).dtype == dtypes.default_float + assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float + assert Tensor.full([2,3], 3).dtype == dtypes.default_int assert Tensor.full([2,3], True).dtype == dtypes.bool + assert Tensor.zeros(3, 3).dtype == dtypes.default_float + assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + + assert Tensor.ones(3, 3).dtype == dtypes.default_float + assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + + assert Tensor.full((3, 3), 3).dtype == dtypes.default_int + assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float + assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64 + def test_reduce_0d_default(self): - assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type + assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int # requires reduceop acc fix - def test_arange(self): - assert Tensor.arange(5).dtype == dtypes.int32 - assert Tensor.arange(5.0).dtype == Tensor.default_type + @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_arange(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + + assert Tensor.arange(5).dtype == dtypes.default_int + assert Tensor.arange(5.0).dtype == dtypes.default_float assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16 assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64 assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16 - assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type - assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type - - def test_zeros(self): - assert Tensor.zeros(3, 3).dtype == Tensor.default_type - assert Tensor.zeros(3, 3, dtype= dtypes.float16).dtype == dtypes.float16 - assert Tensor.zeros(3, 3, dtype= dtypes.int64).dtype == dtypes.int64 - - def test_ones(self): - assert Tensor.ones(3, 3).dtype == Tensor.default_type - assert Tensor.ones(3, 3, dtype= dtypes.float16).dtype == dtypes.float16 - assert Tensor.ones(3, 3, dtype= dtypes.int64).dtype == dtypes.int64 - - def test_full(self): - assert Tensor.full((3, 3), 3).dtype == dtypes.int - assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type - assert Tensor.full((3, 3), 3, dtype= dtypes.float16).dtype == dtypes.float16 - assert Tensor.full((3, 3), 3, dtype= dtypes.int64).dtype == dtypes.int64 - - def test_eye(self): - assert Tensor.eye(0).dtype == Tensor.default_type - assert Tensor.eye(3).dtype == Tensor.default_type - assert Tensor.eye(3, dtype= dtypes.float16).dtype == dtypes.float16 - assert Tensor.eye(3, dtype= dtypes.int64).dtype == dtypes.int64 + assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float + assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float core_types = list(DTYPES_DICT.values()) floats = [dt for dt in core_types if dtypes.is_float(dt)] @@ -331,10 +353,10 @@ class TestAutoCastType(unittest.TestCase): np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4) def test_broadcast_float(self): - assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == Tensor.default_type - assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == Tensor.default_type - assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == Tensor.default_type - assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == Tensor.default_type + assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float + assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float + assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float + assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16 assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16 assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32 diff --git a/test/test_hip_rdna3.py b/test/test_hip_rdna3.py index f5a06160ef..96c71e95ed 100644 --- a/test/test_hip_rdna3.py +++ b/test/test_hip_rdna3.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad import Tensor, Device -from tinygrad.helpers import dtypes +from tinygrad import Tensor, Device, dtypes from examples.beautiful_mnist import Model as MNIST from examples.hlb_cifar10 import SpeedyResNet @@ -23,8 +22,8 @@ class TestHIPCompilationRDNA(unittest.TestCase): output.numpy() def test_compile_hip_speedyresnet_hf(self): - old_default_type = Tensor.default_type - Tensor.default_type = dtypes.float16 + old_default_float = dtypes.default_float + dtypes.default_float = dtypes.float16 W = Tensor.rand(12,3,2,2) model = SpeedyResNet(W) @@ -33,7 +32,7 @@ class TestHIPCompilationRDNA(unittest.TestCase): output = model(input) output.numpy() - Tensor.default_type = old_default_type + dtypes.default_float = old_default_float if __name__ == "__main__": unittest.main() diff --git a/test/test_tensor.py b/test/test_tensor.py index 5984e28b1c..5fcf825eb8 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -252,7 +252,7 @@ class TestTinygrad(unittest.TestCase): # empty tensor defaults for arr in ([], [[[]]], [[],[]]): t = Tensor(arr) - assert t.dtype == Tensor.default_type + assert t.dtype == dtypes.default_float np.testing.assert_allclose(t.numpy(), np.array(arr)) # mixture of bool and int @@ -264,7 +264,7 @@ class TestTinygrad(unittest.TestCase): # mixture of bool, int and float for arr in ([[True,True],[3.,True]], [[0,1],[3.,4]], [[[0],[1]],[[3.],[4]]], [[[True],[1]],[[3.],[4]]]): t = Tensor(arr) - assert t.dtype == Tensor.default_type + assert t.dtype == dtypes.default_float np.testing.assert_allclose(t.numpy(), np.array(arr)) def test_tensor_list_shapes(self): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 260c95c0bb..a1e3038c7f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -181,6 +181,9 @@ class dtypes: @staticmethod def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32) + default_float: ClassVar[DType] = float32 + default_int: ClassVar[DType] = int32 + # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], @@ -198,7 +201,8 @@ def least_upper_dtype(*ds:DType) -> DType: def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32) # HACK: staticmethods are not callable in 3.8 so we have to compare the class -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod} +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if ( + not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)} INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()} class GlobalCounters: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2dcb80a7fc..45bc955804 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -46,7 +46,6 @@ class Tensor: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev no_grad: ClassVar[bool] = False - default_type: ClassVar[DType] = dtypes.float32 def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501 assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" device = Device.canonicalize(device) @@ -61,14 +60,14 @@ class Tensor: self._ctx: Optional[Function] = None if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data) - elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.int32, device, data) - elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data) + elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_int, device, data) + elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_float, device, data) elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) - elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np)) + elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np)) elif isinstance(data, list): if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool - elif d and all_int(d): dtype = dtype or dtypes.int32 - else: dtype = dtype or Tensor.default_type + elif d and all_int(d): dtype = dtype or dtypes.default_int + else: dtype = dtype or dtypes.default_float # NOTE: cast at the end for the types that do not have a numpy dtype data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype) elif isinstance(data, np.ndarray): @@ -153,7 +152,7 @@ class Tensor: @staticmethod def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs): assert isinstance(sz, int), f"cannot create with symbolic size {sz}" - return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) # noqa: E501 + return Tensor(LazyBuffer.loadop(op, (sz,), dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) @staticmethod def empty(*shape, **kwargs): @@ -171,8 +170,9 @@ class Tensor: @staticmethod def full(shape:Tuple[sint, ...], fill_value, **kwargs): - # TODO: dtypes.default_type and dtypes.from_py - dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32) + # TODO: dtypes.from_py + dtype = kwargs.pop("dtype", + dtypes.default_float if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.default_int) return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod @@ -184,7 +184,7 @@ class Tensor: @staticmethod def arange(start, stop=None, step=1, **kwargs): if stop is None: stop, start = start, 0 - dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32) + dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step) @staticmethod @@ -202,7 +202,7 @@ class Tensor: def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor: # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform src = Tensor.rand(2, *shape, **kwargs) - return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype) + return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float) @staticmethod def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(shape, low=low, high=high, dtype=dtypes.int32) @@ -212,7 +212,7 @@ class Tensor: @staticmethod def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor: - dtype = kwargs.pop("dtype", Tensor.default_type) + dtype = kwargs.pop("dtype", dtypes.default_float) return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low @staticmethod @@ -733,7 +733,7 @@ class Tensor: if 0 in self.shape: return self, self.full_like(y) if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype else: - y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.int32 if isinstance(y, int) else Tensor.default_type + y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.default_int if isinstance(y, int) else dtypes.default_float x = x.cast(y_dtype) y = Tensor(y, self.device, y_dtype, requires_grad=False)