From fe707bc96819e5662df3dd1f3e1a69e57c1c98ec Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 24 Jun 2024 21:12:34 +0300 Subject: [PATCH] hotfix: don't use is for comparing dtype (#5128) --- tinygrad/codegen/uops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 6217f1a4ae..326265bb71 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -462,7 +462,7 @@ class UOpGraph: uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype if uop in (UOps.CONST, UOps.DEFINE_ACC): if uop is UOps.DEFINE_ACC: - assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}" + assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}" arg = src[0].arg assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg