dtypes.as_const -> DType.const (#15337)

does not need to be a staticmethod
This commit is contained in:
chenyu
2026-03-18 00:48:41 -04:00
committed by GitHub
parent 864d3917d5
commit ac7a348d06
6 changed files with 24 additions and 25 deletions

View File

@@ -77,15 +77,15 @@ class TestUOps(unittest.TestCase):
def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
a = dtypes.as_const(a, dts[0])
a = dts[0].const(a)
self._equal(f([a], op, dts), fxn(a))
def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
a = dtypes.as_const(a, dts[0])
b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
a = dts[0].const(a)
b = dts[1].const(abs(b) if no_b_neg else b)
self._equal(f([a,b], op, dts), fxn(a,b))
def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
@@ -93,9 +93,9 @@ class TestUOps(unittest.TestCase):
for a in [-2.0, 0, 1]:
for b in [-3.0, 3.0]:
for c in [-4.0, 4.0]:
a = dtypes.as_const(a, dts[0])
b = dtypes.as_const(b, dts[1])
c = dtypes.as_const(c, dts[2])
a = dts[0].const(a)
b = dts[1].const(b)
c = dts[2].const(c)
self._equal(f([a,b,c], op, dts), fxn(a,b,c))
class TestFloatUOps(TestUOps):
@@ -117,7 +117,7 @@ class TestFloatUOps(TestUOps):
def test_cmpne_nan(self): # NaN != x for any x (IEEE 754)
for a, b in [(math.nan, 1.0), (1.0, math.nan), (math.nan, math.nan)]:
self.assertTrue(_test_single_value(
[dtypes.as_const(a, dtypes.float32), dtypes.as_const(b, dtypes.float32)],
[dtypes.float32.const(a), dtypes.float32.const(b)],
Ops.CMPNE, (dtypes.float32, dtypes.float32)))
# MOD isn't tested on floats

View File

@@ -87,6 +87,15 @@ class DType(metaclass=DTypeMetaClass):
def max(self):
if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
return float("inf") if dtypes.is_float(self) else True
def const(self, val: tuple[ConstType, ...]|ConstType):
"""Coerce *val* into this dtype's canonical Python-level constant representation.

Accepts either a scalar or, for vector dtypes, a tuple whose length must equal
self.count (each element is coerced recursively).
NOTE(review): indentation here reflects the diff rendering, not the real file layout.
"""
if isinstance(val, tuple):
# vector constant: arity must match this dtype's element count
assert len(val) == self.count, f"mismatch {val} {self}"
return tuple(map(self.const, val))
# Invalid sentinel values pass through untouched (presumably used for masked/invalid lanes — confirm at call sites)
if isinstance(val, InvalidType): return val
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(val, float) and math.isnan(val): val = math.nan
# int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
return ConstFloat(float(val)) if dtypes.is_float(self) else bool(val) if dtypes.is_bool(self) else int(val)
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
@@ -165,16 +174,6 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
if isinstance(val, InvalidType): return val
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(val, float) and math.isnan(val): val = math.nan
# int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
@staticmethod
def finfo(dtype:DType) -> tuple[int, int]:
"""(exponent, mantissa)"""
if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")

View File

@@ -112,7 +112,7 @@ class PythonProgram:
elif uop is Ops.VECTORIZE: values[i] = src_values
elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.CAST:
values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \

View File

@@ -61,7 +61,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
assert dtype.fmt is not None, f"{dtype=} has None fmt"
truncate_function = truncate[dtype]
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)])
data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
return ret
@@ -197,7 +197,7 @@ class Tensor(OpMixin):
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
def const_like(self, b:ConstType) -> Tensor: return Tensor(dtypes.as_const(b, self.dtype), self.device, self.dtype, requires_grad=False)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.dtype.const(b), self.device, self.dtype, requires_grad=False)
def requires_grad_(self, requires_grad=True) -> Tensor:
# make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps
@@ -2975,7 +2975,7 @@ class Tensor(OpMixin):
y_dtype = x.dtype
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)
if match_dtype and x.dtype != y.dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)

View File

@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
@@ -471,7 +471,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
arg=dtypes.as_const(b, dtype),
arg=dtype.const(b),
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
@staticmethod

View File

@@ -36,7 +36,7 @@ shared_spec = PatternMatcher([
(UPat(Ops.SINK, dtypes.void), lambda: True), # NOTE: for testing, we let sinks be anything
# CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
@@ -239,7 +239,7 @@ program_spec = PatternMatcher([
(UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
type(x.arg) is type(x.dtype.const(x.arg))),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),