hotfix: don't use is for comparing dtype (#5128)

This commit is contained in:
qazal
2024-06-24 21:12:34 +03:00
committed by GitHub
parent dfa562dbc1
commit fe707bc968

View File

@@ -462,7 +462,7 @@ class UOpGraph:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg