LazyBuffer.const type check cleanup [run_process_replay] (#5795)
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Union, Optional, Any, Tuple, List
+from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, ConstType
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
@@ -71,7 +71,7 @@ class LazyBuffer:
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

   def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
-    assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
+    assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
     shape = self.shape if shape is None else shape
     return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
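For context, a minimal standalone sketch of why the new assert works, assuming ConstType is a typing alias along the lines of Union[float, int, bool] (as defined in tinygrad.dtype): typing.get_args unpacks an alias into a tuple of its member types, and isinstance accepts such a tuple, so the check stays in sync with ConstType instead of hardcoding (int, float, bool).

# Minimal sketch, not the tinygrad source; ConstType here is an assumed
# stand-in for the alias imported from tinygrad.dtype.
from typing import Union, get_args

ConstType = Union[float, int, bool]

# get_args unpacks the alias into its member types.
print(get_args(ConstType))  # (<class 'float'>, <class 'int'>, <class 'bool'>)

# isinstance accepts the tuple returned by get_args, so the assert from the
# diff validates against whatever ConstType currently contains.
val = 3.0
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"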