diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 86f16877e7..cfac0ce7d8 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Union, Optional, Any, Tuple, List +from typing import Union, Optional, Any, Tuple, List, get_args from tinygrad.dtype import dtypes, DType, ConstType from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st @@ -71,7 +71,7 @@ class LazyBuffer: return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache) def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer: - assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType" + assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType" shape = self.shape if shape is None else shape return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)