LazyBuffer.const type check cleanup [run_process_replay] (#5795)

This commit is contained in:
chenyu
2024-07-29 16:17:14 -04:00
committed by GitHub
parent 2cadf21684
commit fc393d710d

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
@@ -71,7 +71,7 @@ class LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
shape = self.shape if shape is None else shape
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)