From 3c12244cfc47c3dc4597b2d65d135866fd49fd90 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 4 Oct 2024 22:49:21 -0400 Subject: [PATCH] remove DTypeLike from lazy (#6897) keep only in tensor --- tinygrad/engine/lazy.py | 8 ++++---- tinygrad/tensor.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index d274981627..b7c954c9e4 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Union, Optional, Any, Tuple, List, get_args -from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype +from tinygrad.dtype import dtypes, DType, ConstType, to_dtype from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU from tinygrad.ops import identity_element, MathTrait, resolve, UOp @@ -10,7 +10,7 @@ from tinygrad.device import Buffer from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() -def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), +def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))): if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None dtype = to_dtype(dtype) @@ -25,7 +25,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DSP", "DISK"} class LazyBuffer(MathTrait): - def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike, + def __init__(self, device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None): self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata @@ -68,7 +68,7 @@ class LazyBuffer(MathTrait): def lbs(self) -> List[LazyBuffer]: return [self] @staticmethod - def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer: + def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer: assert isinstance(src, tuple) return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e124626319..13cdac3b5b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -389,10 +389,11 @@ class Tensor: @staticmethod def _metaop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs): + dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): - return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \ - for d in device], None), device, dtype, **kwargs) - return Tensor(LazyBuffer.metaop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs) + return Tensor(MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None), + device, dtype, **kwargs) + return Tensor(LazyBuffer.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod def empty(*shape, **kwargs):