Tuple -> tuple, List -> list [pr] (#8936)

This commit is contained in:
chenyu
2025-02-06 14:21:19 -05:00
committed by GitHub
parent d5183e1584
commit a092b6395d
9 changed files with 43 additions and 47 deletions

View File

@@ -1,5 +1,5 @@
import time
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.ops import UOp, Ops, sint
@@ -40,14 +40,14 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
def ast_const(dtype:DType, val:ConstType, shape:tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
st = unwrap(st_src[0].st)
if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),))
return UOp.const(dtype, val).valid(st)
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]:
st = time.perf_counter_ns()
ret = fxn(*args, **kwargs)
return ret, (time.perf_counter_ns()-st)*1e-6