diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 6217f1a4ae..326265bb71 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -462,7 +462,7 @@ class UOpGraph: uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype if uop in (UOps.CONST, UOps.DEFINE_ACC): if uop is UOps.DEFINE_ACC: - assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}" + assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}" arg = src[0].arg assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg