From fc393d710de707d4a9b9b417a29486403f8cc217 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 29 Jul 2024 16:17:14 -0400 Subject: [PATCH] LazyBuffer.const type check cleanup [run_process_replay] (#5795) --- tinygrad/lazy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 86f16877e7..cfac0ce7d8 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Union, Optional, Any, Tuple, List +from typing import Union, Optional, Any, Tuple, List, get_args from tinygrad.dtype import dtypes, DType, ConstType from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st @@ -71,7 +71,7 @@ class LazyBuffer: return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache) def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer: - assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType" + assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType" shape = self.shape if shape is None else shape return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)