mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add InvalidType to ConstType [pr] (#14373)
* add InvalidType to ConstType [pr] TYPED=1 python test/test_tiny.py passes. added PyConst = float|int|bool for some Tensor level input types * hcq
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -241,7 +241,7 @@ jobs:
|
||||
python -m mypy --lineprecision-report .
|
||||
cat lineprecision.txt
|
||||
- name: Run TYPED=1
|
||||
run: TYPED=1 python -c "import tinygrad"
|
||||
run: CHECK_OOB=0 DEV=CPU TYPED=1 python test/test_tiny.py
|
||||
|
||||
unittest:
|
||||
name: Unit Tests
|
||||
|
||||
@@ -33,7 +33,8 @@ class InvalidType:
|
||||
|
||||
Invalid = InvalidType()
|
||||
|
||||
ConstType = float|int|bool
|
||||
PyConst = float|int|bool
|
||||
ConstType = PyConst|InvalidType
|
||||
|
||||
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
|
||||
|
||||
@@ -148,7 +149,7 @@ class dtypes:
|
||||
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
|
||||
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
|
||||
@staticmethod
|
||||
def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
|
||||
def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
|
||||
if isinstance(val, tuple):
|
||||
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
|
||||
return tuple(dtypes.as_const(x, dtype) for x in val)
|
||||
|
||||
@@ -136,7 +136,7 @@ class MathMixin:
|
||||
return self._binop(Ops.MOD, x, reverse)
|
||||
|
||||
def sub(self, x: Self | ConstType, reverse: bool = False):
|
||||
return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
||||
return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, -self.ufix(x))
|
||||
|
||||
def div(self, x: Self | ConstType, reverse: bool = False):
|
||||
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
|
||||
@@ -243,7 +243,7 @@ class MathMixin:
|
||||
return self.alu(Ops.MAX, self.ufix(x))
|
||||
|
||||
def minimum(self, x: Self | ConstType):
|
||||
return -(-self).maximum(-x)
|
||||
return -(-self).maximum(-self.ufix(x))
|
||||
|
||||
def where(self, x: Self | ConstType, y: Self | ConstType):
|
||||
if isinstance(x, type(self)):
|
||||
|
||||
@@ -593,7 +593,7 @@ class AMDProgram(HCQProgram):
|
||||
base=self.lib_gpu.va_addr)
|
||||
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev)
|
||||
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
|
||||
if self.dev.pmc_enabled:
|
||||
|
||||
@@ -73,7 +73,7 @@ class CPUProgram(HCQProgram):
|
||||
try: rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or WIN else 'libgcc_s.so.1')
|
||||
except OSError: pass
|
||||
|
||||
def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, tuple[int, int]]|None=None, **kwargs):
|
||||
def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, int]|None=None, **kwargs):
|
||||
self.runtimevars = runtimevars or {}
|
||||
|
||||
LVP = isinstance(dev.renderer, LVPRenderer)
|
||||
|
||||
@@ -307,7 +307,7 @@ class NVProgram(HCQProgram):
|
||||
yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz
|
||||
start_off += (sz if typ == 0x4 else 0) + 4
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
|
||||
raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}")
|
||||
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
|
||||
|
||||
@@ -257,7 +257,7 @@ class QCOMProgram(HCQProgram):
|
||||
super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
|
||||
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
|
||||
if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
|
||||
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
|
||||
|
||||
@@ -283,20 +283,21 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
|
||||
if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))
|
||||
|
||||
class HCQArgsState(Generic[ProgramType]):
|
||||
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
|
||||
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=()):
|
||||
self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
|
||||
self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []
|
||||
|
||||
def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))
|
||||
|
||||
class CLikeArgsState(HCQArgsState[ProgramType]):
|
||||
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
|
||||
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=(), prefix:list[int]|None=None):
|
||||
super().__init__(buf, prg, bufs, vals=vals)
|
||||
|
||||
if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
|
||||
|
||||
self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
|
||||
self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
|
||||
assert None not in vals
|
||||
self.bind_sints_to_buf(*cast(tuple[sint, ...], vals), buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
|
||||
|
||||
class HCQProgram(Generic[HCQDeviceType]):
|
||||
def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
|
||||
@@ -307,7 +308,7 @@ class HCQProgram(Generic[HCQDeviceType]):
|
||||
@staticmethod
|
||||
def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
|
||||
|
||||
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
|
||||
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int|None, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
|
||||
"""
|
||||
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
|
||||
Args:
|
||||
@@ -322,7 +323,7 @@ class HCQProgram(Generic[HCQDeviceType]):
|
||||
return self.args_state_t(argsbuf, self, bufs, vals=vals)
|
||||
|
||||
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
|
||||
vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None:
|
||||
"""
|
||||
Enqueues the program for execution with the given arguments and dimensions.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
@@ -112,7 +112,7 @@ class Tensor(OpMixin):
|
||||
__slots__ = "uop", "requires_grad", "grad"
|
||||
training: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
def __init__(self, data:PyConst|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
|
||||
@@ -138,7 +138,7 @@ class Tensor(OpMixin):
|
||||
data = data.replace(src=(var.replace(src=const.src), const))
|
||||
elif data is None:
|
||||
data = Tensor(0, device=_device, dtype=_dtype or dtypes.default_float, requires_grad=requires_grad).uop
|
||||
elif isinstance(data, get_args(ConstType)):
|
||||
elif isinstance(data, get_args(PyConst)):
|
||||
data = (UOp.unique_const if _force_unique or requires_grad else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
|
||||
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
@@ -321,7 +321,7 @@ class Tensor(OpMixin):
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
return self._buffer().as_typed_buffer(self.shape)
|
||||
|
||||
def item(self) -> ConstType:
|
||||
def item(self) -> PyConst:
|
||||
"""
|
||||
Returns the value of this tensor as a standard Python number.
|
||||
|
||||
@@ -334,7 +334,7 @@ class Tensor(OpMixin):
|
||||
return self.data()[(0,) * len(self.shape)]
|
||||
|
||||
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
|
||||
def tolist(self) -> ConstType|list[Any]:
|
||||
def tolist(self) -> PyConst|list[Any]:
|
||||
"""
|
||||
Returns the value of this tensor as a nested list.
|
||||
Returns single value for const tensor.
|
||||
@@ -618,7 +618,7 @@ class Tensor(OpMixin):
|
||||
# ***** creation helper functions *****
|
||||
|
||||
@staticmethod
|
||||
def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
|
||||
def full(shape:tuple[sint, ...], fill_value:PyConst, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the given shape, filled with the given value.
|
||||
|
||||
@@ -748,7 +748,7 @@ class Tensor(OpMixin):
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
|
||||
|
||||
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
|
||||
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates a tensor with the same shape as `self`, filled with the given value.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
@@ -1268,12 +1268,12 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return self._getitem(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|ConstType) -> None:
|
||||
def __setitem__(self, indices, v:Tensor|PyConst) -> None:
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
self.realize()._getitem(indices).assign(v)
|
||||
return
|
||||
# NOTE: check that setitem target is valid first
|
||||
if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if isinstance(v, get_args(PyConst)): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
||||
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
||||
self.realize()
|
||||
@@ -1544,7 +1544,7 @@ class Tensor(OpMixin):
|
||||
for i, s in enumerate(self.shape)], dim=-1)
|
||||
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim)).reshape(-1, self.ndim)
|
||||
|
||||
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor:
|
||||
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|PyConst) -> Tensor:
|
||||
"""
|
||||
Replaces `self` with `value` wherever the elements of `mask` are True.
|
||||
|
||||
@@ -1864,7 +1864,7 @@ class Tensor(OpMixin):
|
||||
|
||||
# https://keccak.team/keccak_specs_summary.html
|
||||
|
||||
def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64):
|
||||
def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
|
||||
# TODO: contiguous is here for compile speed
|
||||
return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
|
||||
rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
|
||||
@@ -2621,7 +2621,7 @@ class Tensor(OpMixin):
|
||||
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
|
||||
return src, mask
|
||||
|
||||
def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
|
||||
def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
|
||||
"""
|
||||
Scatters `src` values along an axis specified by `dim`.
|
||||
Apply `add` or `multiply` reduction operation with `reduce`.
|
||||
@@ -2686,7 +2686,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
src, mask = self._pre_scatter(dim, index, src)
|
||||
def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
||||
def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
|
||||
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
|
||||
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
|
||||
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
|
||||
|
||||
@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace, ConstFloat
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
|
||||
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
def resolve(x:UOp|bool, default:bool=True):
|
||||
@@ -720,7 +720,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]:
|
||||
def pop_const(self, op=Ops.ADD) -> tuple[UOp, PyConst]: # NOTE: assume Invalid ALU is resolved
|
||||
return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype))
|
||||
@staticmethod
|
||||
def gcd(*uops: UOp) -> UOp:
|
||||
@@ -740,11 +740,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
|
||||
def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
|
||||
@property
|
||||
def vmin(self) -> ConstType: return self._min_max[0]
|
||||
def vmin(self) -> PyConst: return self._min_max[0]
|
||||
@property
|
||||
def vmax(self) -> ConstType: return self._min_max[1]
|
||||
def vmax(self) -> PyConst: return self._min_max[1]
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> tuple[ConstType, ConstType]:
|
||||
def _min_max(self) -> tuple[PyConst, PyConst]:
|
||||
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
|
||||
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
|
||||
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
||||
@@ -943,7 +943,7 @@ class UPat(OpMixin):
|
||||
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
|
||||
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
|
||||
@staticmethod
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# lil helper
|
||||
def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
|
||||
@@ -1451,4 +1451,4 @@ def pyrender(ast:UOp) -> str:
|
||||
sint = int|UOp
|
||||
Variable = UOp
|
||||
|
||||
ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]
|
||||
ConstLike = ConstType|Variable|tuple[ConstType, ...]
|
||||
|
||||
Reference in New Issue
Block a user