mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
do not construct unmasked VALID (#8759)
* new lines that exist in codegen/ops * update tests * update sops.gz (13071 -> 13070 asts) * fix viz too * remove that TODO * diff pruning * mask assert + device * work * diff pruning * re: fix viz too --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.engine.realize import Runner
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import T
|
||||
from tinygrad.helpers import T, unwrap
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.rewriter import full_graph_rewrite
|
||||
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
|
||||
@@ -43,7 +43,9 @@ def rand_for_dtype(dt:DType, size:int):
|
||||
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
|
||||
if st_src is None:
|
||||
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
|
||||
return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
|
||||
st = unwrap(st_src[0].st)
|
||||
if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),))
|
||||
return UOp.const(dtype, val).valid(st)
|
||||
|
||||
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
|
||||
st = time.perf_counter_ns()
|
||||
|
||||
Reference in New Issue
Block a user