add UOps.VALID (#6387)

* uops valid

* broke full_shape

* fixup that st (hardcoded asts still red)

* fixup DEFINE_VAR

debug

more debug

* start moving stuff to ast_const

* move test_linearizer

* move test_linearizer_failures to ast_const

* fixup test_schedule

* small diff change

* regenerate dataset

* fixup test_multitensor

* regen dataset try 2

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2024-09-09 16:58:43 +08:00
committed by GitHub
parent e1d61b048b
commit 8186e4e7d6
9 changed files with 23 additions and 20 deletions

View File

@@ -64,8 +64,8 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
  """Build a constant AST node expressed as a VALID-gated select between `val` and 0.

  Args:
    dtype: dtype passed to both `UOp.const` branches.
    val: value of the first (true) branch of the `where`.
    shape: only used when neither `st` nor `st_src` is given, to build a default ShapeTracker.
    st: optional ShapeTracker; its `to_uop()` becomes the single source of the VALID op.
    st_src: optional pre-built source tuple; when provided, `st` and `shape` are not consulted.

  Returns:
    `UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))`.
  """
  if st_src is None:
    if st is None:
      # No tracker supplied: start from the empty shape, reshape to len(shape) ones, then expand to `shape`.
      st = ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)
    # The source list is a one-element tuple; parenthesised so the tuple is not hidden in a trailing comma.
    st_src = (st.to_uop(),)
  valid = UOp(UOps.VALID, dtypes.bool, st_src)
  # NOTE(review): presumably `where` yields `val` where the view is valid and 0 elsewhere - confirm against UOp.where.
  return valid.where(UOp.const(dtype, val), UOp.const(dtype, 0))
# Type variable linking timeit's return annotation to the return type of the callable it is given.
T = TypeVar("T")
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]: