mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
dtypes.default_float and dtypes.default_int (#2824)
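Summary: this commit removes the class-level Tensor.default_type and replaces it with two defaults kept as class attributes on dtypes: dtypes.default_float (initially float32) and dtypes.default_int (initially int32). A minimal sketch of the new API, mirroring assertions added in test_dtype.py (illustrative only, not part of the diff):

  from tinygrad import Tensor, dtypes

  dtypes.default_float = dtypes.float16   # was: Tensor.default_type = dtypes.float16
  dtypes.default_int = dtypes.int64
  assert Tensor(2.34).dtype == dtypes.float16
  assert Tensor(2).dtype == dtypes.int64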
@@ -12,7 +12,7 @@ if __name__ == "__main__":
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
-from typing import Any, Dict, Optional, SupportsIndex, Type, Union
+from typing import Any, Dict, Optional, SupportsIndex
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from tinygrad import nn
 from tinygrad.nn.state import get_state_dict
@@ -27,10 +27,10 @@ from tinygrad.jit import TinyJit
 BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)

 if getenv("HALF", 0):
-  Tensor.default_type = dtypes.float16
-  np_dtype: Type[Union[np.float16, np.float32]] = np.float16
+  dtypes.default_float = dtypes.float16
+  np_dtype = np.float16
 else:
-  Tensor.default_type = dtypes.float32
+  dtypes.default_float = dtypes.float32
   np_dtype = np.float32

 class BatchNorm(nn.BatchNorm2d):
@@ -52,13 +52,13 @@ class ConvGroup:
     x = x.max_pool2d(2)
     x = x.float()
     x = self.norm1(x)
-    x = x.cast(Tensor.default_type)
+    x = x.cast(dtypes.default_float)
     x = x.gelu()
     residual = x
     x = self.conv2(x)
     x = x.float()
     x = self.norm2(x)
-    x = x.cast(Tensor.default_type)
+    x = x.cast(dtypes.default_float)
     x = x.gelu()

     return x + residual
@@ -277,7 +277,8 @@ def train_cifar():
   X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])

   # Convert data and labels to the default dtype
-  X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
+  X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
+  X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)

   # parse the training params into bias and non-bias
   params_dict = get_state_dict(model)
@@ -482,8 +482,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
   output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
   output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
   scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
-  x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
-  y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
+  x_out = Tensor.arange(output_shape[-1]).cast(dtypes.default_float)
+  y_out = Tensor.arange(output_shape[-2]).cast(dtypes.default_float)
   if mode == "nearest":
     x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
     x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
test/external/external_test_jit_on_models.py (vendored)
@@ -15,8 +15,8 @@ def helper_test_jitted_correctness(gen, train, train_jit):

 class TestJittedModels(unittest.TestCase):
   def test_jitted_tiny_llama(self):
-    old_type = Tensor.default_type
-    Tensor.default_type = dtypes.float16
+    old_float = dtypes.default_float
+    dtypes.default_float = dtypes.float16

     args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
@@ -26,7 +26,7 @@ class TestJittedModels(unittest.TestCase):
     @TinyJit
     def test_jit(t): return model(t, 0).realize()
     helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit)
-    Tensor.default_type = old_type
+    dtypes.default_float = old_float

   @unittest.skipUnless(not CI, "huge for CI")
   def test_jitted_stable_diffusion(self):
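These tests override the global default and restore it afterwards. A sketch of that pattern (the try/finally wrapper is my addition; the tests above restore manually):

  from tinygrad import dtypes

  old_float = dtypes.default_float
  dtypes.default_float = dtypes.float16
  try:
    pass  # run half-precision model code here
  finally:
    dtypes.default_float = old_float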
@@ -40,11 +40,11 @@ class TestRealWorld(unittest.TestCase):
     gc.collect()
     global global_mem_used
     global_mem_used = GlobalCounters.mem_used
-    self.old_type = Tensor.default_type
+    self.old_float = dtypes.default_float
     np.random.seed(2002)

   def tearDown(self):
-    Tensor.default_type = self.old_type
+    dtypes.default_float = self.old_float

   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
   @unittest.skipIf(CI, "too big for CI")
@@ -68,7 +68,7 @@ class TestRealWorld(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_llama(self):
-    Tensor.default_type = dtypes.float16
+    dtypes.default_float = dtypes.float16

     args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
@@ -80,7 +80,7 @@ class TestRealWorld(unittest.TestCase):

   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
-    Tensor.default_type = dtypes.float16
+    dtypes.default_float = dtypes.float16

     args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000}
     model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))
@@ -224,56 +224,78 @@ class TestHelpers(unittest.TestCase):
     assert dtype.vec(amt).scalar() == dtype

 class TestTypeSpec(unittest.TestCase):
-  def test_creation(self):
-    assert Tensor(True).dtype == dtypes.bool
-    assert Tensor(2).dtype == dtypes.int
-    assert Tensor(2.34).dtype == Tensor.default_type
-    assert Tensor([]).dtype == Tensor.default_type
-    assert Tensor([1]).dtype == dtypes.int
-    assert Tensor([1.1]).dtype == Tensor.default_type
-    assert Tensor([0, 1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
+  def setUp(self):
+    self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
+  def tearDown(self):
+    dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float

-  def test_const_full(self):
-    assert Tensor.ones([2,3]).dtype == Tensor.default_type
-    assert Tensor.zeros([2,3]).dtype == Tensor.default_type
-    assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
-    assert Tensor.full([2,3], 3).dtype == dtypes.int
+  def test_set_dtype_default(self):
+    dtypes.default_int = dtypes.int16
+    assert dtypes.default_int == dtypes.int16
+    dtypes.default_int = dtypes.int64
+    assert dtypes.default_int == dtypes.int64
+    dtypes.default_int = dtypes.int32
+    assert dtypes.default_int == dtypes.int32
+    dtypes.default_float = dtypes.float16
+    assert dtypes.default_float == dtypes.float16
+    dtypes.default_float = dtypes.float64
+    assert dtypes.default_float == dtypes.float64
+
+  @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  def test_creation(self, default_int, default_float):
+    dtypes.default_int, dtypes.default_float = default_int, default_float
+    assert Tensor(True).dtype == dtypes.bool
+    assert Tensor(2).dtype == dtypes.default_int
+    assert Tensor(2.34).dtype == dtypes.default_float
+    assert Tensor([]).dtype == dtypes.default_float
+    assert Tensor([1]).dtype == dtypes.default_int
+    assert Tensor([1.1]).dtype == dtypes.default_float
+    assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
+
+    assert Tensor.eye(0).dtype == dtypes.default_float
+    assert Tensor.eye(3).dtype == dtypes.default_float
+    assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
+
+
+  @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  def test_full(self, default_int, default_float):
+    dtypes.default_int, dtypes.default_float = default_int, default_float
+
+    assert Tensor.ones([2,3]).dtype == dtypes.default_float
+    assert Tensor.zeros([2,3]).dtype == dtypes.default_float
+    assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float
+    assert Tensor.full([2,3], 3).dtype == dtypes.default_int
+    assert Tensor.full([2,3], True).dtype == dtypes.bool
+
+    assert Tensor.zeros(3, 3).dtype == dtypes.default_float
+    assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+
+    assert Tensor.ones(3, 3).dtype == dtypes.default_float
+    assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+
+    assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
+    assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
+    assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64

   def test_reduce_0d_default(self):
-    assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
+    assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float
     # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int # requires reduceop acc fix

-  def test_arange(self):
-    assert Tensor.arange(5).dtype == dtypes.int32
-    assert Tensor.arange(5.0).dtype == Tensor.default_type
+  @given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  def test_arange(self, default_int, default_float):
+    dtypes.default_int, dtypes.default_float = default_int, default_float
+
+    assert Tensor.arange(5).dtype == dtypes.default_int
+    assert Tensor.arange(5.0).dtype == dtypes.default_float
     assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
     assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64
     assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16
-    assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type
-    assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type
-
-  def test_zeros(self):
-    assert Tensor.zeros(3, 3).dtype == Tensor.default_type
-    assert Tensor.zeros(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
-    assert Tensor.zeros(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
-
-  def test_ones(self):
-    assert Tensor.ones(3, 3).dtype == Tensor.default_type
-    assert Tensor.ones(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
-    assert Tensor.ones(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
-
-  def test_full(self):
-    assert Tensor.full((3, 3), 3).dtype == dtypes.int
-    assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type
-    assert Tensor.full((3, 3), 3, dtype= dtypes.float16).dtype == dtypes.float16
-    assert Tensor.full((3, 3), 3, dtype= dtypes.int64).dtype == dtypes.int64
-
-  def test_eye(self):
-    assert Tensor.eye(0).dtype == Tensor.default_type
-    assert Tensor.eye(3).dtype == Tensor.default_type
-    assert Tensor.eye(3, dtype= dtypes.float16).dtype == dtypes.float16
-    assert Tensor.eye(3, dtype= dtypes.int64).dtype == dtypes.int64
+    assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
+    assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float

 core_types = list(DTYPES_DICT.values())
 floats = [dt for dt in core_types if dtypes.is_float(dt)]
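The rewritten TestTypeSpec parametrizes the defaults with hypothesis instead of hard-coding them. A minimal standalone sketch of that pattern (assumes hypothesis is installed, as the test suite requires; check_defaults is a hypothetical name, not from the diff):

  from hypothesis import given, strategies as st
  from tinygrad import Tensor, dtypes

  @given(st.sampled_from([dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]),
         st.sampled_from([dtypes.float16, dtypes.float32, dtypes.float64]))
  def check_defaults(default_int, default_float):
    # mirrors test_creation: tensor creation follows whatever the defaults are set to
    dtypes.default_int, dtypes.default_float = default_int, default_float
    assert Tensor(2).dtype == default_int
    assert Tensor(2.34).dtype == default_float

  check_defaults()  # hypothesis runs the body over the sampled combinations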
@@ -331,10 +353,10 @@ class TestAutoCastType(unittest.TestCase):
     np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)

   def test_broadcast_float(self):
-    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == Tensor.default_type
-    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == Tensor.default_type
-    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == Tensor.default_type
-    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == Tensor.default_type
+    assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
+    assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
+    assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
+    assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
     assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
     assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
     assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad import Tensor, Device
-from tinygrad.helpers import dtypes
+from tinygrad import Tensor, Device, dtypes
 from examples.beautiful_mnist import Model as MNIST
 from examples.hlb_cifar10 import SpeedyResNet
@@ -23,8 +22,8 @@ class TestHIPCompilationRDNA(unittest.TestCase):
     output.numpy()

   def test_compile_hip_speedyresnet_hf(self):
-    old_default_type = Tensor.default_type
-    Tensor.default_type = dtypes.float16
+    old_default_float = dtypes.default_float
+    dtypes.default_float = dtypes.float16

     W = Tensor.rand(12,3,2,2)
     model = SpeedyResNet(W)
@@ -33,7 +32,7 @@ class TestHIPCompilationRDNA(unittest.TestCase):
     output = model(input)
     output.numpy()

-    Tensor.default_type = old_default_type
+    dtypes.default_float = old_default_float

 if __name__ == "__main__":
   unittest.main()
@@ -252,7 +252,7 @@ class TestTinygrad(unittest.TestCase):
     # empty tensor defaults
     for arr in ([], [[[]]], [[],[]]):
       t = Tensor(arr)
-      assert t.dtype == Tensor.default_type
+      assert t.dtype == dtypes.default_float
       np.testing.assert_allclose(t.numpy(), np.array(arr))

     # mixture of bool and int
@@ -264,7 +264,7 @@ class TestTinygrad(unittest.TestCase):
     # mixture of bool, int and float
     for arr in ([[True,True],[3.,True]], [[0,1],[3.,4]], [[[0],[1]],[[3.],[4]]], [[[True],[1]],[[3.],[4]]]):
       t = Tensor(arr)
-      assert t.dtype == Tensor.default_type
+      assert t.dtype == dtypes.default_float
       np.testing.assert_allclose(t.numpy(), np.array(arr))

   def test_tensor_list_shapes(self):
@@ -181,6 +181,9 @@ class dtypes:
   @staticmethod
   def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)

+  default_float: ClassVar[DType] = float32
+  default_int: ClassVar[DType] = int32
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
@@ -198,7 +201,8 @@ def least_upper_dtype(*ds:DType) -> DType:
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)

 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
+  not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
 INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}

 class GlobalCounters:
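The defaults are plain class attributes on dtypes, and the DTYPES_DICT comprehension now skips any name starting with "default" so they are not mistaken for concrete dtypes. A small check of that behaviour (a sketch; assumes tinygrad.helpers exports DTYPES_DICT as shown above):

  from tinygrad.helpers import dtypes, DTYPES_DICT

  assert dtypes.default_float == dtypes.float32  # initial values from this diff
  assert dtypes.default_int == dtypes.int32
  assert "default_float" not in DTYPES_DICT and "default_int" not in DTYPES_DICT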
@@ -46,7 +46,6 @@ class Tensor:
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev

   no_grad: ClassVar[bool] = False
-  default_type: ClassVar[DType] = dtypes.float32
   def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = Device.canonicalize(device)
@@ -61,14 +60,14 @@ class Tensor:
     self._ctx: Optional[Function] = None
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data)
-    elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.int32, device, data)
-    elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
+    elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_int, device, data)
+    elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_float, device, data)
     elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
-    elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np))
+    elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np))
     elif isinstance(data, list):
       if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
-      elif d and all_int(d): dtype = dtype or dtypes.int32
-      else: dtype = dtype or Tensor.default_type
+      elif d and all_int(d): dtype = dtype or dtypes.default_int
+      else: dtype = dtype or dtypes.default_float
       # NOTE: cast at the end for the types that do not have a numpy dtype
       data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
     elif isinstance(data, np.ndarray):
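With these branches, the dtype a Tensor gets from raw Python data is: bool stays bool, ints follow dtypes.default_int, floats and empty lists follow dtypes.default_float. Mirroring the assertions in test_dtype.py (sketch, not part of the diff):

  from tinygrad import Tensor, dtypes

  assert Tensor(True).dtype == dtypes.bool
  assert Tensor(2).dtype == dtypes.default_int        # previously dtypes.int32
  assert Tensor(2.34).dtype == dtypes.default_float   # previously Tensor.default_type
  assert Tensor([1]).dtype == dtypes.default_int
  assert Tensor([1.1]).dtype == dtypes.default_float
  assert Tensor([]).dtype == dtypes.default_float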
@@ -153,7 +152,7 @@ class Tensor:
   @staticmethod
   def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
     assert isinstance(sz, int), f"cannot create with symbolic size {sz}"
-    return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) # noqa: E501
+    return Tensor(LazyBuffer.loadop(op, (sz,), dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

   @staticmethod
   def empty(*shape, **kwargs):
@@ -171,8 +170,9 @@ class Tensor:

   @staticmethod
   def full(shape:Tuple[sint, ...], fill_value, **kwargs):
-    # TODO: dtypes.default_type and dtypes.from_py
-    dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32)
+    # TODO: dtypes.from_py
+    dtype = kwargs.pop("dtype",
+      dtypes.default_float if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.default_int)
     return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

   @staticmethod
@@ -184,7 +184,7 @@ class Tensor:
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
-    dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32)
+    dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
     return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)

   @staticmethod
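Tensor.full and Tensor.arange now pick between the two defaults based on whether their Python arguments are floats; again mirroring the new tests (sketch):

  from tinygrad import Tensor, dtypes

  assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
  assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
  assert Tensor.arange(5).dtype == dtypes.default_int
  assert Tensor.arange(5.0).dtype == dtypes.default_float
  assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float  # any float argument promotes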
@@ -202,7 +202,7 @@ class Tensor:
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
     src = Tensor.rand(2, *shape, **kwargs)
-    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)

   @staticmethod
   def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(shape, low=low, high=high, dtype=dtypes.int32)
@@ -212,7 +212,7 @@ class Tensor:

   @staticmethod
   def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
-    dtype = kwargs.pop("dtype", Tensor.default_type)
+    dtype = kwargs.pop("dtype", dtypes.default_float)
     return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

   @staticmethod
@@ -733,7 +733,7 @@ class Tensor:
     if 0 in self.shape: return self, self.full_like(y)
     if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
     else:
-      y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.int32 if isinstance(y, int) else Tensor.default_type
+      y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.default_int if isinstance(y, int) else dtypes.default_float
     x = x.cast(y_dtype)
     y = Tensor(y, self.device, y_dtype, requires_grad=False)
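In scalar broadcasting, a Python float mixed with a bool or integer tensor now resolves to dtypes.default_float instead of Tensor.default_type, while explicit float dtypes keep their precision. The corresponding assertions from test_dtype.py (sketch):

  from tinygrad import Tensor, dtypes

  assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
  assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16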