do not construct unmasked VALID (#8759)

* new lines that exist in codegen/ops

* update tests

* update sops.gz (13071 -> 13070 asts)

* fix viz too

* remove that TODO

* diff pruning

* mask assert + device

* work

* diff pruning

* re: fix viz too

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
qazal
2025-01-28 13:51:21 -05:00
committed by GitHub
parent 3417bc1814
commit ba17786068
11 changed files with 58 additions and 65 deletions

View File

@@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import T
from tinygrad.helpers import T, unwrap
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.rewriter import full_graph_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
@@ -43,7 +43,9 @@ def rand_for_dtype(dt:DType, size:int):
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
st = unwrap(st_src[0].st)
if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),))
return UOp.const(dtype, val).valid(st)
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
st = time.perf_counter_ns()