split out unique_const and cache const [pr] (#13493)

* split out unique_const

* add cache to const

* call const in unique_const
This commit is contained in:
George Hotz
2025-11-29 10:44:28 -08:00
committed by GitHub
parent c38b7684dc
commit 6a140f74fe
4 changed files with 23 additions and 14 deletions

View File

@@ -72,6 +72,7 @@ if __name__ == "__main__":
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()
fold_divmod_general.cache_clear()
UOp.const.cache_clear()
Tensor._device_seeds.clear()
Tensor._device_rng_counters.clear()

View File

@@ -128,6 +128,8 @@ class dtypes:
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
if isinstance(val, InvalidType): return val
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(val, float) and math.isnan(val): val = math.nan
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
@functools.cache

View File

@@ -134,8 +134,10 @@ class Tensor(OpMixin):
# give the bound constant a device
const = UOp.const(var.dtype, val, _device, ())
data = data.replace(src=(var.replace(src=const.src), const)) # type: ignore
elif data is None: data = UOp.const(_dtype or dtypes.default_float, 0, _device, (), unique=_force_unique)
elif isinstance(data, get_args(ConstType)): data = UOp.const(_dtype or dtypes.from_py(data), data, _device, (), unique=_force_unique)
elif data is None:
data = (UOp.unique_const if _force_unique else UOp.const)(_dtype or dtypes.default_float, 0, _device)
elif isinstance(data, get_args(ConstType)):
data = (UOp.unique_const if _force_unique else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
elif isinstance(data, (list, tuple)):
if _dtype is None:
@@ -146,8 +148,10 @@ class Tensor(OpMixin):
elif is_numpy_ndarray(data):
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == (): data = UOp.const(_dtype or _from_np_dtype(data.dtype), data.item(), _device, (), unique=_force_unique)
else: data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data) # type: ignore [name-defined]
if data.shape == ():
data = (UOp.unique_const if _force_unique else UOp.const)(_dtype or _from_np_dtype(data.dtype), data.item(), _device)
else:
data = _fromnp(data.astype(npdtype) if _dtype is not None and (npdtype:=_to_np_dtype(_dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
_dtype = _dtype or dtypes.uint8
data = UOp.new_buffer(f"DISK:{data.resolve()}", data.stat().st_size // _dtype.itemsize, _dtype)

View File

@@ -429,20 +429,22 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(op, out_dtype, (self,)+src, **kwargs)
@staticmethod
@functools.cache
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, unique:bool|int=False):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a VCONST if they are all the same
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(b, float) and math.isnan(b): b = math.nan
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
if device is not None:
if unique or not isinstance(unique, bool): ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device), UOp.unique(None if unique is True else unique)))
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
elif unique or not isinstance(unique, bool): raise RuntimeError("unique consts only with DEVICE")
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
arg=dtypes.as_const(b, dtype),
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
@staticmethod
def unique_const(dtype:DType, b:ConstType, device:str|tuple[str, ...], unique=True):
# NOTE: b is ConstType, not ConstLike, so UOps and tuples aren't allowed
assert not isinstance(b, (UOp, tuple)), "unique const only works on numbers"
ret = UOp.const(dtype, b, device)
return ret.replace(src=ret.src + (UOp.unique(None if unique is True else unique),))
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@@ -1359,7 +1361,7 @@ sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
lambda x,d,u: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
lambda x,d,u: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: