Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
remove UOp.metaop [pr] (#10664)
* little simpler UOp.const_like [pr]
* remove UOp.metaop
* bind
* remove
* min diff
* that comment is fine
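The upshot for callers: the removed UOp.metaop helper is replaced by an extended UOp.const that takes optional device and shape arguments. A minimal before/after sketch (the tinygrad.ops module path and the "CPU" device name are assumptions, not part of the diff):

from tinygrad import dtypes
from tinygrad.ops import UOp

# before this commit, a device-backed constant went through the helper:
#   data = UOp.metaop(Ops.CONST, (2, 3), dtypes.float, "CPU", 1.0)
# after it, UOp.const builds the same CONST(VIEW(DEVICE)) graph directly:
data = UOp.const(dtypes.float, 1.0, device="CPU", shape=(2, 3))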
@@ -138,9 +138,13 @@ class Tensor(MathTrait):
     # create a UOp from the different types of inputs
     if isinstance(data, UOp):
       assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
-      if data.op is Ops.BIND: data = UOp.metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
-    elif data is None: data = UOp.metaop(Ops.CONST, (0,), dtype or dtypes.default_float, device, arg=0)
-    elif isinstance(data, get_args(ConstType)): data = UOp.metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
+      if data.op is Ops.BIND:
+        var, val = data.unbind()
+        # give the bound constant a device
+        const = UOp.const(var.dtype, val, device, ())
+        data = data.replace(src=(var.replace(src=const.src), const))
+    elif data is None: data = UOp.const(dtype or dtypes.default_float, 0, device, ())
+    elif isinstance(data, get_args(ConstType)): data = UOp.const(dtype or dtypes.from_py(data), data, device, ())
     elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
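For context, the rewritten BIND branch is the path taken when a Tensor is constructed from a bound Variable. A small usage sketch (assuming the usual top-level exports; not part of the diff):

from tinygrad import Tensor, Variable

# Variable(...).bind(3) yields a BIND UOp; Tensor.__init__ now unbinds it,
# gives the bound constant a DEVICE source via UOp.const, and rewires the BIND
v = Variable("i", 1, 10).bind(3)
t = Tensor(v)
print(t.item())  # expected: 3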
@@ -151,7 +155,7 @@ class Tensor(MathTrait):
     elif str(type(data)) == "<class 'numpy.ndarray'>":
       import numpy as np
       assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
-      if data.shape == (): data = UOp.metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
+      if data.shape == (): data = UOp.const(dtype or _from_np_dtype(data.dtype), data.item(), device, ())
       else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)  # type: ignore [name-defined]
     elif isinstance(data, pathlib.Path):
       dtype = dtype or dtypes.uint8
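The 0-d numpy branch follows the same pattern: the scalar is lifted straight into a UOp.const with the requested device and an empty shape, with no numpy buffer copy. A quick sketch (assumes numpy is installed):

import numpy as np
from tinygrad import Tensor

t = Tensor(np.array(3.0, dtype=np.float32))  # shape () hits the const path
print(t.item())  # expected: 3.0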
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args, Sequence
+from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Sequence
 import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
 from enum import auto, IntEnum, Enum
 from dataclasses import dataclass, field
@@ -378,8 +378,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
-    if self._device is not None or self.st is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self._device, b)
-    return UOp.const(self.dtype, b)
+    return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)
   def broadcast(self, count:int):
     assert self.dtype.count == 1
     if count == 1: return self
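Since const can now attach the device and ShapeTracker itself, const_like collapses to a single call with unchanged behavior at the call site. A sketch (the tinygrad.ops module path is an assumption):

from tinygrad import dtypes
from tinygrad.ops import Ops, UOp

x = UOp.const(dtypes.int, 7)    # bare const: no DEVICE or ShapeTracker source
zero = x.const_like(0)          # matches x's dtype, device, and shape
assert zero.op is Ops.CONST and zero.arg == 0 and zero.dtype == dtypes.int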
@@ -407,10 +406,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
     return UOp(arg, out_dtype, (self,)+src)
   @staticmethod
-  def const(dtype:DType, b:ConstLike):
+  def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0]  # doesn't have to be a VCONST if they are all the same
-    return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+    ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
+    if shape is not None:
+      from tinygrad.shape.shapetracker import ShapeTracker
+      ret = ret.replace(src=(ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),))
+    if device is not None:
+      ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
+    return ret
   def valid(self): return UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)
   @staticmethod
   def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
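With the new keyword arguments, const builds the tensor-style graph inline: a CONST whose single source is a VIEW of a DEVICE, with the ShapeTracker from_shape(()) reshaped to (1,)*len(shape) and expanded to shape. A sketch of the resulting structure (module path and device name are assumptions):

from tinygrad import dtypes
from tinygrad.ops import Ops, UOp

c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(2, 3))
assert c.op is Ops.CONST and c.shape == (2, 3)
assert c.src[0].op is Ops.VIEW and c.src[0].src[0].op is Ops.DEVICE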
@@ -471,18 +476,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   # *** from LazyBuffer ***
 
-  @staticmethod
-  def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str|tuple[str, ...]|None=None, arg=None) -> UOp:
-    from tinygrad.shape.shapetracker import ShapeTracker
-    # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
-    if op is Ops.CONST:
-      assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
-      return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, () if device is None else (UOp(Ops.DEVICE, arg=device),),
-                                                          ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)),))
-    # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
-    assert op is Ops.BIND, f"unknown op {op}"
-    var, val = arg.unbind()
-    return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
   def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
     assert arg is None or isinstance(self.device, tuple)
     inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)