diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 9c11a08b24..7f3469c7dc 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -590,7 +590,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.repeat": lambda x,*repeats: Tensor.repeat(x,*repeats).contiguous(), # not a view
   "aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
   "aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
-  "aten.random_": lambda self: Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype),
+  "aten.random_": lambda self: Tensor.randint(*self.shape, low=self.dtype.min, high=self.dtype.max, device=self.device, dtype=self.dtype),
   "aten.random_.from": lambda self, from_, to: Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype),
   "aten.uniform_": lambda self, low=0, high=1: Tensor.uniform(*self.shape, low=low, high=high, dtype=self.dtype),
   "aten.normal_": lambda self, mean=0, std=1: Tensor.normal(*self.shape, mean=mean, std=std, dtype=self.dtype),
diff --git a/test/backend/test_dtype.py b/test/backend/test_dtype.py
index 2501590d28..c5fe197163 100644
--- a/test/backend/test_dtype.py
+++ b/test/backend/test_dtype.py
@@ -101,7 +101,7 @@ class TestDType(unittest.TestCase):
   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
   def test_uint_overflow(self):
     if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
-    v = dtypes.max(self.DTYPE)
+    v = self.DTYPE.max
     _test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
     _test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
@@ -516,4 +516,3 @@ class TestOpsBFloat16(unittest.TestCase):
 
 if __name__ == '__main__':
   unittest.main()
-
diff --git a/test/backend/test_dtype_alu.py b/test/backend/test_dtype_alu.py
index e5d3e0599e..9e4a51e88a 100644
--- a/test/backend/test_dtype_alu.py
+++ b/test/backend/test_dtype_alu.py
@@ -366,7 +366,7 @@ class TestDTypeALU(unittest.TestCase):
 
   @unittest.expectedFailure
   def test_unsafe_cast_float_to_int_failure(self):
-    val = float(dtypes.max(dtypes.int32) - 1)
+    val = float(dtypes.int32.max - 1)
     t1 = Tensor([val], dtype=dtypes.float32).cast(dtypes.int32)
     t2 = Tensor(val, dtype=dtypes.float32).cast(dtypes.int32)
     np.testing.assert_equal(t1.item(), t2.item())
diff --git a/test/backend/test_ops.py b/test/backend/test_ops.py
index a507a0830f..28a4f2795b 100644
--- a/test/backend/test_ops.py
+++ b/test/backend/test_ops.py
@@ -479,9 +479,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
     helper_test_op(None, torch.maximum, Tensor.maximum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True)
@@ -496,9 +496,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.])
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]])
     helper_test_op(None, torch.minimum, Tensor.minimum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.max(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.max], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum,
-      vals=[[-1234, 0, 1234, dtypes.max(dtypes.int), dtypes.min(dtypes.int)], dtypes.min(dtypes.int)], forward_only=True)
+      vals=[[-1234, 0, 1234, dtypes.int.max, dtypes.int.min], dtypes.int.min], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True)
     helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True)
diff --git a/test/backend/test_quantize_onnx.py b/test/backend/test_quantize_onnx.py
index 72d74b4251..af45d87900 100644
--- a/test/backend/test_quantize_onnx.py
+++ b/test/backend/test_quantize_onnx.py
@@ -204,13 +204,13 @@ class TestQuantizeOnnx(unittest.TestCase):
     W = Tensor(m2:=(np.random.uniform(0, 255, size=(N,N)).astype(wi))).realize()
     tg_dtype = dtypes.int8 if xi == np.int8 else dtypes.uint8
     out = (X.int().matmul(W.int())//1000)
-    if clip: out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+    if clip: out = out.clip(tg_dtype.min, tg_dtype.max)
     out = out.cast(tg_dtype)
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] if opts is None else opts
     sexec(out, opts, replace_src, run_count=1)
     tout = out.numpy()
     mout = ((m1.astype(np.int32) @ m2.astype(np.int32)) // 1000)
-    if clip: mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
+    if clip: mout = mout.clip(tg_dtype.min, tg_dtype.max)
     mout = mout.astype(xi)
     print(tout)
     print(mout)
diff --git a/test/backend/test_renderer_failures.py b/test/backend/test_renderer_failures.py
index 59928be324..3f9fa2afaa 100644
--- a/test/backend/test_renderer_failures.py
+++ b/test/backend/test_renderer_failures.py
@@ -62,7 +62,7 @@ class TestRendererFailures(unittest.TestCase):
 class TestCStyleFailures(unittest.TestCase):
   def test_inline_const_alu(self):
     # CPU doesn't use the max function
-    ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
+    ret = _setup_and_test_alu(Ops.MAX, 1, UOp.const(dtypes.int, dtypes.int.min+1))
     self.assertEqual(ret[0], 1)
 
   def _test_src_strip_paren(self, op: Ops, should_strip_paren:bool=True):
diff --git a/test/null/test_dtype_spec.py b/test/null/test_dtype_spec.py
index 49f157cbe3..197f9316eb 100644
--- a/test/null/test_dtype_spec.py
+++ b/test/null/test_dtype_spec.py
@@ -75,20 +75,20 @@ class TestHelpers(unittest.TestCase):
   def test_dtype_range(self):
     for dt in core_dtypes:
       if dtypes.is_float(dt):
-        np.testing.assert_equal(dtypes.min(dt), -math.inf)
-        np.testing.assert_equal(dtypes.max(dt), math.inf)
+        np.testing.assert_equal(dt.min, -math.inf)
+        np.testing.assert_equal(dt.max, math.inf)
         np.testing.assert_equal(dt.min, -math.inf)
         np.testing.assert_equal(dt.max, math.inf)
       elif dtypes.is_int(dt):
         info = np.iinfo(_to_np_dtype(dt))
-        np.testing.assert_equal(dtypes.min(dt), info.min)
-        np.testing.assert_equal(dtypes.max(dt), info.max)
+        np.testing.assert_equal(dt.min, info.min)
+        np.testing.assert_equal(dt.max, info.max)
         np.testing.assert_equal(dt.min, info.min)
         np.testing.assert_equal(dt.max, info.max)
       else:
         assert dt == dtypes.bool, dt
-        np.testing.assert_equal(dtypes.min(dt), False)
-        np.testing.assert_equal(dtypes.max(dt), True)
+        np.testing.assert_equal(dt.min, False)
+        np.testing.assert_equal(dt.max, True)
         np.testing.assert_equal(dt.min, False)
         np.testing.assert_equal(dt.max, True)
diff --git a/test/null/test_uop_vmin_vmax.py b/test/null/test_uop_vmin_vmax.py
index ef889ee01b..9b307a813b 100644
--- a/test/null/test_uop_vmin_vmax.py
+++ b/test/null/test_uop_vmin_vmax.py
@@ -64,8 +64,8 @@ class TestVminVmaxProperties(unittest.TestCase):
     # negative mask: x & -1 could be anything since -1 has all bits set
     uop = x & -1
-    self.assertEqual(uop.vmin, dtypes.min(dtypes.int32))
-    self.assertEqual(uop.vmax, dtypes.max(dtypes.int32))
+    self.assertEqual(uop.vmin, dtypes.int32.min)
+    self.assertEqual(uop.vmax, dtypes.int32.max)
 
   def test_vmin_vmax_multiplication_with_variable(self):
     # vmin and vmax for multiplication with a variable
@@ -136,8 +136,8 @@ class TestVminVmaxProperties(unittest.TestCase):
     self.assertEqual(x_bool.vmin, False)
     self.assertEqual(x_bool.vmax, True)
     x_uint = x.cast(dtypes.uint)
-    self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
-    self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
+    self.assertEqual(x_uint.vmin, dtypes.uint.min)
+    self.assertEqual(x_uint.vmax, dtypes.uint.max)
 
   def test_vmin_vmax_invalid(self):
     i = UOp.invalid()
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index e20da4fb4d..4e37205c2e 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -79,10 +79,14 @@ class DType(metaclass=DTypeMetaClass):
     return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
   def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
   def nbytes(self) -> int: raise RuntimeError("only ptr types have nbytes")
-  @property
-  def min(self): return dtypes.min(self)
-  @property
-  def max(self): return dtypes.max(self)
+  @functools.cached_property
+  def min(self):
+    if dtypes.is_int(self): return 0 if dtypes.is_unsigned(self) else -2**(self.scalar().bitsize-1)
+    return -float("inf") if dtypes.is_float(self) else False
+  @functools.cached_property
+  def max(self):
+    if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
+    return float("inf") if dtypes.is_float(self) else True
 
 @dataclass(frozen=True, eq=False)
 class PtrDType(DType):
@@ -172,16 +176,6 @@ class dtypes:
     # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
     return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
   @staticmethod
-  @functools.cache
-  def min(dtype:DType):
-    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().bitsize-1)
-    return -float("inf") if dtypes.is_float(dtype) else False
-  @staticmethod
-  @functools.cache
-  def max(dtype:DType):
-    if dtypes.is_int(dtype): return 2**(dtype.scalar().bitsize)-1+dtypes.min(dtype)
-    return float("inf") if dtypes.is_float(dtype) else True
-  @staticmethod
   def finfo(dtype:DType) -> tuple[int, int]:
     """(exponent, mantissa)"""
     if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py
index e9174b4f99..73c0feddbd 100644
--- a/tinygrad/nn/onnx.py
+++ b/tinygrad/nn/onnx.py
@@ -497,7 +497,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     o_ = [((i - 1) // s + 1) for i,s in zip(i_, s_)]
     return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
 
-  def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
+  def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtype.min, dtype.max).cast(dtype)
 
   def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
     if axis < 0: axis += x.ndim
@@ -1209,7 +1209,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
 
   def DynamicQuantizeLinear(x: Tensor):
     # only support uint8
-    qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8)
+    qmin, qmax = dtypes.uint8.min, dtypes.uint8.max
     scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin)
     zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8)
     y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 57b9e93bec..8c313bb7bb 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2109,7 +2109,7 @@ class Tensor(OpMixin):
     x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
     x_cummax, _ = x.cummax(-1)
     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
-    ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
+    ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), self.dtype.min).exp().sum(-1).log() + x_cummax
     return ret.transpose(-1, axis)
 
   def argmax(self, axis=None, keepdim=False) -> Tensor:
@@ -2306,12 +2306,12 @@ class Tensor(OpMixin):
     axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
     pads = self._resolve_pool_pads(padding, len(k_))
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
-    pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    pooled = self.pad(pads, value=self.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
     if not return_indices: return pooled.max(axis)
     spatial_sz = int(math.prod(spatial_shape := self.shape[-len(k_):]))
     idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
     m = pooled == pooled.max(axis, keepdim=True)
-    idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    idx = m * idx.pad(pads, value=idx.dtype.min)._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)
 
   def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
@@ -2752,8 +2752,8 @@ class Tensor(OpMixin):
     def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
     if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
     if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
-    if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
-    if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
+    if reduce == "amax": return mask.where(src, m := src.dtype.min).max(-1).maximum(self if include_self else _inv_mask(self, m))
+    if reduce == "amin": return mask.where(src, m := src.dtype.max).min(-1).minimum(self if include_self else _inv_mask(self, m))
     if reduce == "mean":
       count = mask.where(1, 0).sum(-1).add(1 if include_self else _inv_mask(1, 0))
       return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0)).div(count)
@@ -2782,7 +2782,7 @@ class Tensor(OpMixin):
     # pad to power of 2
     n_stages = (orig_len-1).bit_length()
     pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
-    x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
+    x = x.pad(pads, value=x.dtype.min if descending else x.dtype.max).unflatten(dim, (2,)*n_stages)
     # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
     for stage in range(1, n_stages+1):
       if stage != n_stages:
diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py
index d9d724e534..4ef918f749 100644
--- a/tinygrad/uop/decompositions.py
+++ b/tinygrad/uop/decompositions.py
@@ -287,7 +287,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
   assert d>0, "Sign should have been taken out of divisor"
   vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
   m,s = magicgu(max(vmax, abs(vmin)), d)
-  if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
+  if m*vmin >= x.dtype.min and m*vmax <= x.dtype.max:
     return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
   # before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
   if (largest_factor_of_two_in_d := (d & -d)) > 1:
@@ -295,7 +295,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
   if dont_cast: return None
   # promo_lattice needs to return an unsigned type if the type is unsigned
   if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, device):
-    if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
+    if m*vmin >= next_dtype.min and m*vmax <= next_dtype.max:
       return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
   return None
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 23230d0abb..17bfcc3654 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
 
 # https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
 
 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):
@@ -843,8 +843,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.GEP: return self.src[0]._min_max
     # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
     if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
-      return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
-    return dtypes.min(self.dtype), dtypes.max(self.dtype)
+      return max(self.dtype.min, self.src[0].vmin), min(self.src[0].vmax, self.dtype.max)
+    return self.dtype.min, self.dtype.max
 
   @functools.cached_property
   def _sym_fxn(self):
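
Illustrative usage sketch (not part of the patch itself): after this change the per-dtype range lives on DType.min / DType.max as cached properties and the dtypes.min() / dtypes.max() static helpers are removed, so downstream code migrates the same way the call sites above do. The import path and the concrete values below assume a tinygrad tree with this patch applied.

# hypothetical snippet, not from the patch
from tinygrad.dtype import dtypes

# old API, removed by this patch: dtypes.min(dtypes.int8), dtypes.max(dtypes.int8)
# new API: properties on the DType itself
assert (dtypes.int8.min, dtypes.int8.max) == (-128, 127)    # signed int: -2**(bits-1) .. 2**bits-1+min
assert (dtypes.uint8.min, dtypes.uint8.max) == (0, 255)     # unsigned int: 0 .. 2**bits-1
assert (dtypes.float32.min, dtypes.float32.max) == (-float("inf"), float("inf"))  # floats report +/-inf
assert (dtypes.bool.min, dtypes.bool.max) == (False, True)  # bool range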