diff --git a/test/backend/test_uops.py b/test/backend/test_uops.py
index c43f229346..c9dea0b73c 100644
--- a/test/backend/test_uops.py
+++ b/test/backend/test_uops.py
@@ -77,15 +77,15 @@ class TestUOps(unittest.TestCase):
   def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0.0, 1.0]:
-        a = dtypes.as_const(a, dts[0])
+        a = dts[0].const(a)
         self._equal(f([a], op, dts), fxn(a))

   def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
     for f in [_test_single_value, _test_single_value_const]:
       for a in [-2.0, 0.0, 1.0]:
         for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
-          a = dtypes.as_const(a, dts[0])
-          b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
+          a = dts[0].const(a)
+          b = dts[1].const(abs(b) if no_b_neg else b)
           self._equal(f([a,b], op, dts), fxn(a,b))

   def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
@@ -93,9 +93,9 @@ class TestUOps(unittest.TestCase):
       for a in [-2.0, 0, 1]:
         for b in [-3.0, 3.0]:
           for c in [-4.0, 4.0]:
-            a = dtypes.as_const(a, dts[0])
-            b = dtypes.as_const(b, dts[1])
-            c = dtypes.as_const(c, dts[2])
+            a = dts[0].const(a)
+            b = dts[1].const(b)
+            c = dts[2].const(c)
             self._equal(f([a,b,c], op, dts), fxn(a,b,c))

 class TestFloatUOps(TestUOps):
@@ -117,7 +117,7 @@ class TestFloatUOps(TestUOps):
   def test_cmpne_nan(self):  # NaN != x for any x (IEEE 754)
     for a, b in [(math.nan, 1.0), (1.0, math.nan), (math.nan, math.nan)]:
       self.assertTrue(_test_single_value(
-        [dtypes.as_const(a, dtypes.float32), dtypes.as_const(b, dtypes.float32)],
+        [dtypes.float32.const(a), dtypes.float32.const(b)],
         Ops.CMPNE, (dtypes.float32, dtypes.float32)))

   # MOD isn't tested on floats
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 392a7dd92e..5dce3c3e5b 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -87,6 +87,15 @@ class DType(metaclass=DTypeMetaClass):
   def max(self):
     if dtypes.is_int(self): return 2**(self.scalar().bitsize)-1+self.min
     return float("inf") if dtypes.is_float(self) else True
+  def const(self, val: tuple[ConstType, ...]|ConstType):
+    if isinstance(val, tuple):
+      assert len(val) == self.count, f"mismatch {val} {self}"
+      return tuple(map(self.const, val))
+    if isinstance(val, InvalidType): return val
+    # NOTE: float('nan') != float('nan'), so we canonicalize here
+    if isinstance(val, float) and math.isnan(val): val = math.nan
+    # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
+    return ConstFloat(float(val)) if dtypes.is_float(self) else bool(val) if dtypes.is_bool(self) else int(val)

 @dataclass(frozen=True, eq=False)
 class PtrDType(DType):
@@ -165,16 +174,6 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
-    if isinstance(val, tuple):
-      assert len(val) == dtype.count, f"mismatch {val} {dtype}"
-      return tuple(dtypes.as_const(x, dtype) for x in val)
-    if isinstance(val, InvalidType): return val
-    # NOTE: float('nan') != float('nan'), so we canonicalize here
-    if isinstance(val, float) and math.isnan(val): val = math.nan
-    # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
-    return ConstFloat(float(val)) if dtypes.is_float(dtype) else bool(val) if dtypes.is_bool(dtype) else int(val)
-  @staticmethod
   def finfo(dtype:DType) -> tuple[int, int]:
     """(exponent, mantissa)"""
     if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 484b862bb9..ce8225cc49 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -112,7 +112,7 @@ class PythonProgram:
       elif uop is Ops.VECTORIZE: values[i] = src_values
       elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
       elif uop is Ops.CAST:
-        values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
+        values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
       elif uop is Ops.LOAD:
         if dtype.count > 1:
           values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index bfafeb08ed..800a505c7a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -61,7 +61,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp:
     ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape)
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
     truncate_function = truncate[dtype]
-    data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)])
+    data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtype.const(xi)) for xi in fully_flatten(x)])
   # fake realize
   ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
   return ret
@@ -197,7 +197,7 @@ class Tensor(OpMixin):
     lhs,rhs = self._broadcasted(x, reverse)
     return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
   def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
-  def const_like(self, b:ConstType) -> Tensor: return Tensor(dtypes.as_const(b, self.dtype), self.device, self.dtype, requires_grad=False)
+  def const_like(self, b:ConstType) -> Tensor: return Tensor(self.dtype.const(b), self.device, self.dtype, requires_grad=False)

   def requires_grad_(self, requires_grad=True) -> Tensor:
     # make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps
@@ -2975,7 +2975,7 @@ class Tensor(OpMixin):
       y_dtype = x.dtype
     elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
     if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
-    else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
+    else: y = Tensor(y_dtype.const(y), x.device, y_dtype, requires_grad=False)

     if match_dtype and x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 52c2b44440..f7fe6ef127 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}

 # https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op], dt)
+def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])

 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):
@@ -471,7 +471,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       assert len(b) > 0, "can't create const from empty tuple"
       b = b[0]  # doesn't have to be a VCONST if they are all the same
     ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
-              arg=dtypes.as_const(b, dtype),
+              arg=dtype.const(b),
               src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
     return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
   @staticmethod
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 2a219925a3..2f956bc222 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -36,7 +36,7 @@ shared_spec = PatternMatcher([
   (UPat(Ops.SINK, dtypes.void), lambda: True),  # NOTE: for testing, we let sinks be anything

   # CONST/DEFINE_VAR are everywhere
-  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
   (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

   # ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
@@ -239,7 +239,7 @@ program_spec = PatternMatcher([
   (UPat(GroupOp.All-{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR, Ops.VCONST, Ops.VECTORIZE}, dtype=dtypes.index), lambda: False),
   (UPat(Ops.CONST, arg=Invalid), lambda: False),
   (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg) and len(x.arg)==x.dtype.vcount>1 and
-   type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
+   type(x.arg) is type(x.dtype.const(x.arg))),

   # if has a
   (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
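
For reviewers: the net effect of this patch is that the `dtypes.as_const(val, dtype)` staticmethod becomes a `dtype.const(val)` method on `DType`, so every call site reads dtype-first. Below is a minimal sketch of how the new call sites behave, assuming the method works exactly as added in `tinygrad/dtype.py` above; the dtypes and values are illustrative only and do not come from the patch.

```python
from tinygrad.dtype import dtypes

# old call sites:  dtypes.as_const(val, some_dtype)
# new call sites:  some_dtype.const(val)

# int is the default: floats handed to an int dtype are truncated, bool dtypes collapse to True/False
assert dtypes.int32.const(3.7) == 3
assert dtypes.bool.const(2) is True

# vector dtypes take a tuple whose length must match dtype.count; const maps over the elements
assert dtypes.int32.vec(4).const((1, 2.5, 3, 4)) == (1, 2, 3, 4)

# float dtypes wrap the value (ConstFloat) so -0.0 and 0.0 stay distinct cache keys,
# and NaN is canonicalized to math.nan first, since float('nan') != float('nan')
neg_zero = dtypes.float32.const(-0.0)
nan_const = dtypes.float32.const(float("nan"))
```

The semantics are unchanged from `as_const` (the method body moves verbatim, with `dtype` replaced by `self`), which is why the test, runtime, tensor, and spec call sites in this patch are all one-line rewrites.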