@@ -590,7 +590,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.repeat": lambda x,*repeats: Tensor.repeat(x,*repeats).contiguous(), # not a view
   "aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
   "aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
-  "aten.random_": lambda self: Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype),
+  "aten.random_": lambda self: Tensor.randint(*self.shape, low=self.dtype.min, high=self.dtype.max, device=self.device, dtype=self.dtype),
   "aten.random_.from": lambda self, from_, to: Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype),
   "aten.uniform_": lambda self, low=0, high=1: Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype),
   "aten.normal_": lambda self, mean=0, std=1: Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype),

@@ -101,7 +101,7 @@ class TestDType(unittest.TestCase):
   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
   def test_uint_overflow(self):
     if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
-    v = dtypes.max(self.DTYPE)
+    v = self.DTYPE.max
     _test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
     _test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)

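For reference, the behavior this test pins down is plain modular wraparound: unsigned arithmetic wraps modulo 2**bits on both sides of the comparison. A minimal standalone check of the numpy side, using uint8 as an illustrative case:

import numpy as np
v = np.array([255], dtype=np.uint8)        # uint8 max
assert (v + 2)[0] == 1                     # 257 % 256 == 1
assert (v * 2)[0] == 254                   # 510 % 256 == 254
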
@@ -516,4 +516,3 @@ class TestOpsBFloat16(unittest.TestCase):
 
 if __name__ == '__main__':
   unittest.main()
-

@@ -366,7 +366,7 @@ class TestDTypeALU(unittest.TestCase):
 
   @unittest.expectedFailure
   def test_unsafe_cast_float_to_int_failure(self):
-    val = float(dtypes.max(dtypes.int32) - 1)
+    val = float(dtypes.int32.max - 1)
     t1 = Tensor([val], dtype=dtypes.float32).cast(dtypes.int32)
     t2 = Tensor(val, dtype=dtypes.float32).cast(dtypes.int32)
     np.testing.assert_equal(t1.item(), t2.item())

@@ -479,9 +479,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
     helper_test_op(None, torch.maximum, Tensor.maximum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)

@@ -496,9 +496,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
     helper_test_op(None, torch.minimum, Tensor.minimum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)

@@ -204,13 +204,13 @@ class TestQuantizeOnnx(unittest.TestCase):
     W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
     tg_dtype = dtypes.int8 if xi == np.int8 else dtypes.uint8
     out = (X.int().matmul(W.int())//1000)
-    if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+    if clip: out = out.clip(tg_dtype.min, tg_dtype.max)
     out = out.cast(tg_dtype)
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
     sexec(out, opts, replace_src, run_count=1)
     tout = out.numpy()
     mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) // 1000)
-    if clip: mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+    if clip: mout = mout.clip(tg_dtype.min, tg_dtype.max)
     mout = mout.astype(xi)
     print(tout)
     print(mout)

@@ -62,7 +62,7 @@ class TestRendererFailures(unittest.TestCase):
 class TestCStyleFailures(unittest.TestCase):
   def test_inline_const_alu(self):
     # CPU doesn't use the max function
-    ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
+    ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.int.min+1))
     self.assertEqual(ret[0], 1)
 
   def _test_src_strip_paren(self, op: Ops, should_strip_paren:bool=True):

@@ -75,20 +75,20 @@ class TestHelpers(unittest.TestCase):
   def test_dtype_range(self):
     for dt in core_dtypes:
       if dtypes.is_float(dt):
-        np.testing.assert_equal(dtypes.min(dt), -math.inf)
-        np.testing.assert_equal(dtypes.max(dt), math.inf)
+        np.testing.assert_equal(dt.min, -math.inf)
+        np.testing.assert_equal(dt.max, math.inf)
       elif dtypes.is_int(dt):
         info = np.iinfo(_to_np_dtype(dt))
-        np.testing.assert_equal(dtypes.min(dt), info.min)
-        np.testing.assert_equal(dtypes.max(dt), info.max)
+        np.testing.assert_equal(dt.min, info.min)
+        np.testing.assert_equal(dt.max, info.max)
       else:
         assert dt == dtypes.bool, dt
-        np.testing.assert_equal(dtypes.min(dt), False)
-        np.testing.assert_equal(dtypes.max(dt), True)
+        np.testing.assert_equal(dt.min, False)
+        np.testing.assert_equal(dt.max, True)

@@ -64,8 +64,8 @@ class TestVminVmaxProperties(unittest.TestCase):
 
     # negative mask: x & -1 could be anything since -1 has all bits set
     uop = x & -1
-    self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
-    self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
+    self.assertEqual(uop.vmin, dtypes.int32.min)
+    self.assertEqual(uop.vmax, dtypes.int32.max)
 
   def test_vmin_vmax_multiplication_with_variable(self):
     # vmin and vmax for multiplication with a variable

@@ -136,8 +136,8 @@ class TestVminVmaxProperties(unittest.TestCase):
     self.assertEqual(x_bool.vmin, False)
     self.assertEqual(x_bool.vmax, True)
     x_uint = x.cast(dtypes.uint)
-    self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
-    self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
+    self.assertEqual(x_uint.vmin, dtypes.uint.min)
+    self.assertEqual(x_uint.vmax, dtypes.uint.max)
 
   def test_vmin_vmax_invalid(self):
     i = UOp.invalid()

@@ -79,10 +79,14 @@ class DType(metaclass=DTypeMetaClass):
     return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
   def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
   def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
-  @property
-  def min(self): return dtypes.min(self)
-  @property
-  def max(self): return dtypes.max(self)
+  @functools.cached_property
+  def min(self):
+    if dtypes.is_int(self): return 0 if dtypes.is_unsigned(self) else -2**(self.scalar().bitsize-1)
+    return -float("inf") if dtypes.is_float(self) else False
+  @functools.cached_property
+  def max(self):
+    if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
+    return float("inf") if dtypes.is_float(self) else True
 
 @dataclass(frozen=True, eq=False)
 class PtrDType(DType):

@@ -172,16 +176,6 @@ class dtypes:
     # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
     return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
-  @staticmethod
-  @functools.cache
-  def min(dtype:DType):
-    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1)
-    return -float("inf") if dtypes.is_float(dtype) else False
-  @staticmethod
-  @functools.cache
-  def max(dtype:DType):
-    if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype)
-    return float("inf") if dtypes.is_float(dtype) else True
   @staticmethod
   def finfo(dtype:DType) -> tuple[int, int]:
     """(exponent, mantissa)"""
     if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")

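The max formula leans on min: for a signed n-bit integer, min is -2**(n-1) and max is 2**n - 1 + min = 2**(n-1) - 1; for unsigned, min is 0 so max is 2**n - 1. A quick sanity check of that arithmetic in plain Python (bit widths chosen for illustration):

for bits, signed, lo, hi in [(8, True, -128, 127), (8, False, 0, 255), (32, True, -2**31, 2**31 - 1)]:
    mn = -2**(bits - 1) if signed else 0   # the min property for int dtypes
    mx = 2**bits - 1 + mn                  # the max property, derived from min
    assert (mn, mx) == (lo, hi)
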
@@ -497,7 +497,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     o_ = [((i - 1) // s + 1) for i,s in zip(i_, s_)]
     return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
 
-def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
+def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtype.min, dtype.max).cast(dtype)
 
 def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
   if axis < 0: axis += x.ndim

@@ -1209,7 +1209,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
 
   def DynamicQuantizeLinear(x: Tensor):
     # only support uint8
-    qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8)
+    qmin, qmax = dtypes.uint8.min, dtypes.uint8.max
     scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin)
     zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8)
     y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8)

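The scale line widens the input range to include zero (max(x_max, 0) - min(x_min, 0), written as a sum of two maxima) so a representable zero_point exists for real 0, per the ONNX DynamicQuantizeLinear spec. A minimal numpy sketch of the same math; the function name is hypothetical and it assumes a float input that is not all zeros:

import numpy as np
def dynamic_quantize_uint8(x: np.ndarray):
    qmin, qmax = 0, 255
    # widen the observed range to include 0, then spread it over 256 levels
    scale = (max(x.max(), 0.0) + max((-x).max(), 0.0)) / (qmax - qmin)
    zero_point = np.clip(np.round(qmin - x.min() / scale), qmin, qmax).astype(np.uint8)
    y = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return y, scale, zero_point
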
@@ -2109,7 +2109,7 @@ class Tensor(OpMixin):
     x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
     x_cummax, _ = x.cummax(-1)
     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
-    ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
+    ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), self.dtype.min).exp().sum(-1).log() + x_cummax
     return ret.transpose(-1, axis)
 
   def argmax(self, axis=None, keepdim=False) -> Tensor:

@@ -2306,12 +2306,12 @@ class Tensor(OpMixin):
     axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
     pads = self._resolve_pool_pads(padding, len(k_))
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
-    pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    pooled = self.pad(pads, value=self.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
     if not return_indices: return pooled.max(axis)
     spatial_sz = int(math.prod(spatial_shape := self.shape[-len(k_):]))
     idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
     m = pooled == pooled.max(axis, keepdim=True)
-    idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    idx = m * idx.pad(pads, value=idx.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)
 
   def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):

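The pad value is the point of both changed lines: filling the border with self.dtype.min guarantees padded cells can never win the max reduction, and padding idx the same way keeps fake positions out of the returned indices. The same trick on a single 1-D window in plain Python, with an int8-style sentinel chosen for illustration:

INT8_MIN = -128                       # sentinel for an int8 input
window = [5, 9, INT8_MIN]             # window hanging over the edge gets the sentinel as filler
assert max(window) == 9               # the filler never wins the max
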
@@ -2752,8 +2752,8 @@ class Tensor(OpMixin):
     def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
     if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
     if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
-    if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
-    if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
+    if reduce == "amax": return mask.where(src, m := src.dtype.min).max(-1).maximum(self if include_self else _inv_mask(self, m))
+    if reduce == "amin": return mask.where(src, m := src.dtype.max).min(-1).minimum(self if include_self else _inv_mask(self, m))
     if reduce == "mean":
       count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
       return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)

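Same sentinel idea here, in both directions: slots masked out of an amax reduction are filled with src.dtype.min, and amin slots with src.dtype.max, so they cannot affect the result. In plain Python, for illustrative int8-range values:

INT8_MIN, INT8_MAX = -128, 127
vals, mask = [3, -7, 50], [True, False, True]
assert max(v if m else INT8_MIN for v, m in zip(vals, mask)) == 50  # masked-out -7 is ignored
assert min(v if m else INT8_MAX for v, m in zip(vals, mask)) == 3
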
@@ -2782,7 +2782,7 @@ class Tensor(OpMixin):
     # pad to power of 2
     n_stages = (orig_len-1).bit_length()
     pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
-    x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
+    x = x.pad(pads, value=x.dtype.min if descending else x.dtype.max).unflatten(dim, (2,)*n_stages)
     # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
     for stage in range(1, n_stages+1):
       if stage != n_stages:

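Two details worth spelling out: the bitonic network needs a power-of-two length, and (orig_len-1).bit_length() is exactly the smallest n with 2**n >= orig_len; the pad sentinel (dtype.max when ascending, dtype.min when descending) sorts to the tail so it can be sliced off afterwards. Checking the padding arithmetic in plain Python:

for orig_len in [1, 2, 5, 8, 9]:
    n_stages = (orig_len - 1).bit_length()
    # smallest power of two that fits orig_len elements
    assert 2**n_stages >= orig_len and (n_stages == 0 or 2**(n_stages - 1) < orig_len)
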
@@ -287,7 +287,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
   assert d>0, "Sign should have been taken out of divisor"
   vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
   m,s = magicgu(max(vmax, abs(vmin)), d)
-  if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
+  if m*vmin >= x.dtype.min and m*vmax <= x.dtype.max:
     return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
   # before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
   if (largest_factor_of_two_in_d := (d & -d)) > 1:

@@ -295,7 +295,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
   if dont_cast: return None
   # promo_lattice needs to return an unsigned type if the type is unsigned
   if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, device):
-    if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
+    if m*vmin >= next_dtype.min and m*vmax <= next_dtype.max:
       return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
   return None

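The dtype.min/dtype.max guards in both hunks exist because fast_idiv replaces x // d with a multiply and a shift, which is only exact while x*m fits in the working dtype. A self-contained sketch of the underlying identity; this is an assumed simplification of what magicgu computes, covering only nonnegative x:

def magic(d: int, n_max: int) -> tuple[int, int]:
    # find m, s with (x*m) >> s == x // d for all 0 <= x <= n_max
    for s in range(2 * n_max.bit_length() + 2):
        m = (2**s + d - 1) // d            # ceil(2**s / d), so e = m*d - 2**s with 0 <= e < d
        if n_max * (m * d - 2**s) < 2**s:  # total rounding error stays below one unit
            return m, s
    raise ValueError("no magic number found")

m, s = magic(5, 2**16)
assert all((x * m) >> s == x // 5 for x in range(2**16 + 1))
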
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
 
 # https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
 
 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):

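dt.min reads naturally here: the identity element of MAX over a dtype is its smallest representable value, since max(v, dt.min) == v for every v the dtype can hold, just as 0 is the identity for ADD and 1 for MUL. For an int8-style range in plain Python:

INT8_MIN, INT8_MAX = -128, 127
assert all(max(v, INT8_MIN) == v for v in range(INT8_MIN, INT8_MAX + 1))
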
@@ -843,8 +843,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.GEP: return self.src[0]._min_max
     # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
     if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
-      return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
-    return dtypes.min(self.dtype), dtypes.max(self.dtype)
+      return max(self.dtype.min, self.src[0].vmin), min(self.src[0].vmax, self.dtype.max)
+    return self.dtype.min, self.dtype.max
 
   @functools.cached_property
   def _sym_fxn(self):

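The CAST branch reads as: a monotone cast preserves order, so the known range of the result is the source range clamped into what the target dtype can represent, and the fallback is the full dtype range. The clamp in isolation, with hypothetical values for an int8 target:

dt_min, dt_max = -128, 127             # target dtype bounds
src_vmin, src_vmax = -1000, 5          # known range of the value being cast
vmin, vmax = max(dt_min, src_vmin), min(src_vmax, dt_max)
assert (vmin, vmax) == (-128, 5)       # lower bound clamped, upper bound kept
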