remove UOp.metaop [pr] (#10664)

* little simpler UOp.const_like [pr]

* remove UOp.metaop

* bind

* remove

* min diff

* that comment is fine
This commit is contained in:
qazal
2025-06-06 16:21:48 +03:00
committed by GitHub
parent 4a6d84c4c3
commit 5170f387b3
2 changed files with 18 additions and 21 deletions

View File

@@ -138,9 +138,13 @@ class Tensor(MathTrait):
# create a UOp from the different types of inputs
if isinstance(data, UOp):
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
if data.op is Ops.BIND: data = UOp.metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
elif data is None: data = UOp.metaop(Ops.CONST, (0,), dtype or dtypes.default_float, device, arg=0)
elif isinstance(data, get_args(ConstType)): data = UOp.metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
if data.op is Ops.BIND:
var, val = data.unbind()
# give the bound constant a device
const = UOp.const(var.dtype, val, device, ())
data = data.replace(src=(var.replace(src=const.src), const))
elif data is None: data = UOp.const(dtype or dtypes.default_float, 0, device, ())
elif isinstance(data, get_args(ConstType)): data = UOp.const(dtype or dtypes.from_py(data), data, device, ())
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
if dtype is None:
@@ -151,7 +155,7 @@ class Tensor(MathTrait):
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == (): data = UOp.metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
if data.shape == (): data = UOp.const(dtype or _from_np_dtype(data.dtype), data.item(), device, ())
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
dtype = dtype or dtypes.uint8

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args, Sequence
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Sequence
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
@@ -378,8 +378,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
if self._device is not None or self.st is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self._device, b)
return UOp.const(self.dtype, b)
return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -407,10 +406,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(arg, out_dtype, (self,)+src)
@staticmethod
def const(dtype:DType, b:ConstLike):
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
if shape is not None:
from tinygrad.shape.shapetracker import ShapeTracker
ret = ret.replace(src=(ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),))
if device is not None:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
return ret
def valid(self): return UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
@@ -471,18 +476,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** from LazyBuffer ***
@staticmethod
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str|tuple[str, ...]|None=None, arg=None) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
# Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
if op is Ops.CONST:
assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, () if device is None else (UOp(Ops.DEVICE, arg=device),),
ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)),))
# Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
assert op is Ops.BIND, f"unknown op {op}"
var, val = arg.unbind()
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
assert arg is None or isinstance(self.device, tuple)
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)