remove staticmethod dtypes.max/min (#15227)

always use x.dtype.max/min
This commit is contained in:
chenyu
2026-03-11 23:11:24 -04:00
committed by GitHub
parent 18dc77ccab
commit 842c978df3
13 changed files with 41 additions and 48 deletions

View File

@@ -590,7 +590,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.repeat": lambda x,*repeats: Tensor.repeat(x,*repeats).contiguous(), # not a view
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
"aten.random_": lambda self: Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype),
"aten.random_": lambda self: Tensor.randint(*self.shape, low=self.dtype.min, high=self.dtype.max, device=self.device, dtype=self.dtype),
"aten.random_.from": lambda self, from_, to: Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype),
"aten.uniform_": lambda self, low=0, high=1: Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype),
"aten.normal_": lambda self, mean=0, std=1: Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype),

View File

@@ -101,7 +101,7 @@ class TestDType(unittest.TestCase):
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
def test_uint_overflow(self):
if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
v = dtypes.max(self.DTYPE)
v = self.DTYPE.max
_test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
_test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
@@ -516,4 +516,3 @@ class TestOpsBFloat16(unittest.TestCase):
if __name__ == '__main__':
unittest.main()

View File

@@ -366,7 +366,7 @@ class TestDTypeALU(unittest.TestCase):
@unittest.expectedFailure
def test_unsafe_cast_float_to_int_failure(self):
val = float(dtypes.max(dtypes.int32) - 1)
val = float(dtypes.int32.max - 1)
t1 = Tensor([val], dtype=dtypes.float32).cast(dtypes.int32)
t2 = Tensor(val, dtype=dtypes.float32).cast(dtypes.int32)
np.testing.assert_equal(t1.item(), t2.item())

View File

@@ -479,9 +479,9 @@ class TestOps(unittest.TestCase):
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
helper_test_op(None, torch.maximum, Tensor.maximum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)
@@ -496,9 +496,9 @@ class TestOps(unittest.TestCase):
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
helper_test_op(None, torch.minimum, Tensor.minimum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
helper_test_op(None, torch.minimum, Tensor.minimum,
vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)

View File

@@ -204,13 +204,13 @@ class TestQuantizeOnnx(unittest.TestCase):
W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
tg_dtype = dtypes.int8 if xi == np.int8 else dtypes.uint8
out = (X.int().matmul(W.int())//1000)
if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
if clip: out = out.clip(tg_dtype.min, tg_dtype.max)
out = out.cast(tg_dtype)
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
sexec(out, opts, replace_src, run_count=1)
tout = out.numpy()
mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) // 1000)
if clip: mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
if clip: mout = mout.clip(tg_dtype.min, tg_dtype.max)
mout = mout.astype(xi)
print(tout)
print(mout)

View File

@@ -62,7 +62,7 @@ class TestRendererFailures(unittest.TestCase):
class TestCStyleFailures(unittest.TestCase):
def test_inline_const_alu(self):
# CPU doesn't use the max function
ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.int.min+1))
self.assertEqual(ret[0], 1)
def _test_src_strip_paren(self, op: Ops, should_strip_paren:bool=True):

View File

@@ -75,20 +75,20 @@ class TestHelpers(unittest.TestCase):
def test_dtype_range(self):
for dt in core_dtypes:
if dtypes.is_float(dt):
np.testing.assert_equal(dtypes.min(dt), -math.inf)
np.testing.assert_equal(dtypes.max(dt), math.inf)
np.testing.assert_equal(dt.min, -math.inf)
np.testing.assert_equal(dt.max, math.inf)
np.testing.assert_equal(dt.min, -math.inf)
np.testing.assert_equal(dt.max, math.inf)
elif dtypes.is_int(dt):
info = np.iinfo(_to_np_dtype(dt))
np.testing.assert_equal(dtypes.min(dt), info.min)
np.testing.assert_equal(dtypes.max(dt), info.max)
np.testing.assert_equal(dt.min, info.min)
np.testing.assert_equal(dt.max, info.max)
np.testing.assert_equal(dt.min, info.min)
np.testing.assert_equal(dt.max, info.max)
else:
assert dt == dtypes.bool, dt
np.testing.assert_equal(dtypes.min(dt), False)
np.testing.assert_equal(dtypes.max(dt), True)
np.testing.assert_equal(dt.min, False)
np.testing.assert_equal(dt.max, True)
np.testing.assert_equal(dt.min, False)
np.testing.assert_equal(dt.max, True)

View File

@@ -64,8 +64,8 @@ class TestVminVmaxProperties(unittest.TestCase):
# negative mask: x & -1 could be anything since -1 has all bits set
uop = x & -1
self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
self.assertEqual(uop.vmin, dtypes.int32.min)
self.assertEqual(uop.vmax, dtypes.int32.max)
def test_vmin_vmax_multiplication_with_variable(self):
# vmin and vmax for multiplication with a variable
@@ -136,8 +136,8 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(x_bool.vmin, False)
self.assertEqual(x_bool.vmax, True)
x_uint = x.cast(dtypes.uint)
self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
self.assertEqual(x_uint.vmin, dtypes.uint.min)
self.assertEqual(x_uint.vmax, dtypes.uint.max)
def test_vmin_vmax_invalid(self):
i = UOp.invalid()

View File

@@ -79,10 +79,14 @@ class DType(metaclass=DTypeMetaClass):
return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
@property
def min(self): return dtypes.min(self)
@property
def max(self): return dtypes.max(self)
@functools.cached_property
def min(self):
if dtypes.is_int(self): return 0 if dtypes.is_unsigned(self) else -2**(self.scalar().bitsize-1)
return -float("inf") if dtypes.is_float(self) else False
@functools.cached_property
def max(self):
if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
return float("inf") if dtypes.is_float(self) else True
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
@@ -172,16 +176,6 @@ class dtypes:
# int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
@staticmethod
@functools.cache
def min(dtype:DType):
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1)
return -float("inf") if dtypes.is_float(dtype) else False
@staticmethod
@functools.cache
def max(dtype:DType):
if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype)
return float("inf") if dtypes.is_float(dtype) else True
@staticmethod
def finfo(dtype:DType) -> tuple[int, int]:
"""(exponent, mantissa)"""
if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")

View File

@@ -497,7 +497,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
o_ = [((i - 1) // s + 1) for i,s in zip(i_, s_)]
return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtype.min, dtype.max).cast(dtype)
def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
if axis < 0: axis += x.ndim
@@ -1209,7 +1209,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
def DynamicQuantizeLinear(x: Tensor):
# only support uint8
qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8)
qmin, qmax = dtypes.uint8.min, dtypes.uint8.max
scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin)
zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8)
y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8)

View File

@@ -2109,7 +2109,7 @@ class Tensor(OpMixin):
x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
x_cummax, _ = x.cummax(-1)
mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), self.dtype.min).exp().sum(-1).log() + x_cummax
return ret.transpose(-1, axis)
def argmax(self, axis=None, keepdim=False) -> Tensor:
@@ -2306,12 +2306,12 @@ class Tensor(OpMixin):
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
pads = self._resolve_pool_pads(padding, len(k_))
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
pooled = self.pad(pads, value=self.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
if not return_indices: return pooled.max(axis)
spatial_sz = int(math.prod(spatial_shape := self.shape[-len(k_):]))
idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
m = pooled == pooled.max(axis, keepdim=True)
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
idx = m * idx.pad(pads, value=idx.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
return pooled.max(axis), spatial_sz - idx.max(axis)
def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
@@ -2752,8 +2752,8 @@ class Tensor(OpMixin):
def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "amax": return mask.where(src, m := src.dtype.min).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := src.dtype.max).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
@@ -2782,7 +2782,7 @@ class Tensor(OpMixin):
# pad to power of 2
n_stages = (orig_len-1).bit_length()
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
x = x.pad(pads, value=x.dtype.min if descending else x.dtype.max).unflatten(dim, (2,)*n_stages)
# https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
for stage in range(1, n_stages+1):
if stage != n_stages:

View File

@@ -287,7 +287,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
assert d>0, "Sign should have been taken out of divisor"
vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
m,s = magicgu(max(vmax, abs(vmin)), d)
if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
if m*vmin >= x.dtype.min and m*vmax <= x.dtype.max:
return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
# before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
if (largest_factor_of_two_in_d := (d & -d)) > 1:
@@ -295,7 +295,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
if dont_cast: return None
# promo_lattice needs to return an unsigned type if the type is unsigned
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, device):
if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
if m*vmin >= next_dtype.min and m*vmax <= next_dtype.max:
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
return None

View File

@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
@@ -843,8 +843,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
return dtypes.min(self.dtype), dtypes.max(self.dtype)
return max(self.dtype.min, self.src[0].vmin), min(self.src[0].vmax, self.dtype.max)
return self.dtype.min, self.dtype.max
@functools.cached_property
def _sym_fxn(self):