diff --git a/extra/datasets/sops.gz b/extra/datasets/sops.gz
index f98d4cfc0a..bce0f6f7b2 100644
Binary files a/extra/datasets/sops.gz and b/extra/datasets/sops.gz differ
diff --git a/extra/optimization/extract_dataset.py b/extra/optimization/extract_dataset.py
index 174c276e37..595ddb7dfe 100755
--- a/extra/optimization/extract_dataset.py
+++ b/extra/optimization/extract_dataset.py
@@ -7,7 +7,7 @@ from test.external.process_replay.process_replay import _run_differ
 PAGE_SIZE = 100
 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
 TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
-LOGOPS = os.getenv("LOGOPS", "/tmp/ops")
+LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
 
 def extract_ast(offset:int):
   logops = open(LOGOPS, "a")
diff --git a/test/helpers.py b/test/helpers.py
index 380186544e..b441359cd1 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -64,8 +64,8 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
 
 def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
   if st_src is None:
-    st_src = st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),
-  return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+    st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
+  return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype))
 
 T = TypeVar("T")
 def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index abccdfe335..b0e9fcca96 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase):
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.ADD
       assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
-      assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
+      assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1
     t = 2 * t
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.MUL
-      assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
+      assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2
       assert ast.src[2].src[1].op is UOps.LOAD
     t = t + t.full_like(3)
     for si in t.schedule():
@@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase):
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.ADD
       assert ast.src[2].src[0].op is UOps.LOAD
-      assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
+      assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3
 
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py
index ceb338bda3..ac4cab29a5 100644
--- a/test/unit/test_verify_ast.py
+++ b/test/unit/test_verify_ast.py
@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
     uop_sts = verify_ast(a.schedule()[-1].ast)
     store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
     self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
-    const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
+    const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
     self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
 
 if __name__ == '__main__':
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 0966876982..3df849cfc2 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -639,7 +639,7 @@ class Kernel:
         # for locals, we use the ShapeTracker that's in the srcs
         st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
         st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
-        if op.op is UOps.VALID: return replace(op, src=(st_uop,))
+        if op.op is UOps.CONST: return replace(op, src=(st_uop,))
         if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
         return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
       if op.op is UOps.REDUCE_AXIS:
@@ -777,8 +777,8 @@ class Kernel:
 def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
   if uop in sts: return
   op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
-  # NOTE: UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, UOps.CONST and UOps.DEFINE_VAR don't have shape
-  if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}: return
+  # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
+  if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
   # restore globals from the two stage reduce
   if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
     _assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
@@ -791,8 +791,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
   assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
   # movementops are pushed to the edges with SHAPETRACKER
   # elementwise inherits shape
-  st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
-  for x in (src[0:1] if len(src) and src[0].op is UOps.VALID else src[1:] if op in BUFFER_UOPS else src):
+  st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
+  for x in (src[1:] if op in BUFFER_UOPS else src):
     if sts[x].shape != st.shape:
       if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
       raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index a7a8bb6282..e42f8d3e36 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -91,9 +91,9 @@ class IndependentLowerer:
   def _to_uop(self, x:UOp) -> UOp:
     if x.op in BUFFER_UOPS:
       idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
-      if x.op is UOps.VALID: return valid
       # TODO: check has_valid in UPat, not here
       has_valid = valid.op is not UOps.CONST or valid.arg is not True
+      if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
       buf = x.src[0]
       if x.op is UOps.LOAD:
         barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 58d3ecccc0..c6efe3af4a 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -66,7 +66,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
       val, var_val = val.unbind()
       var_vals[val] = var_val
     else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
-    return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+    return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
   # otherwise, it's a load and we add it to the inputs
   if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
       ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -136,9 +136,8 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
 
 def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
   if (st:=uop_sts.get(uop)): return st
   if uop.op in BUFFER_UOPS: return uop.st_arg
-  src = [x for x in uop.src if x.op not in {UOps.CONST, UOps.DEFINE_VAR}]
-  src_sts = [xst for x in src if (xst:=get_output_st(x, uop_sts)) is not None]
-  if len(src_sts) != len(src) or not all_same([x.shape for x in src_sts]): return None
+  src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
+  if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
   uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
   return st
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 78d2d503c7..4c4cedb953 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -319,9 +319,8 @@ class UOps(Enum):
   # ops that are not graph nodes
   ENDRANGE = auto()
   ENDIF = auto()
-  VALID = auto()
 
-BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
+BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
 
 END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
 
@@ -349,7 +348,7 @@ class UOp(MathTrait):
             self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
   # *** uop syntactic sugar
   @property
-  def st_loc(self) -> int: return 0 if self.op is UOps.VALID else 1
+  def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
   @property
   def st_arg(self) -> ShapeTracker:
     assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
@@ -385,12 +384,11 @@ class UOp(MathTrait):
   def full_shape(self) -> Tuple[sint, ...]:
     if self.op is UOps.SHAPETRACKER: return self.arg.shape
     # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
-    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, \
-      UOps.CONST, UOps.DEFINE_VAR}]))
+    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
   def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
   def variables(self) -> List[Variable]:
     st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr)
+    return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
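
Context for the hunks above: UOps.VALID is dropped, the const AST node carries its ShapeTracker directly in src (UOp(UOps.CONST, dtype, (st.to_uop(),), val)), and the lowerer reintroduces the mask as valid.where(const, 0). A minimal standalone sketch of that masking semantics follows (plain Python; MaskedConst and lower_const are hypothetical illustration names, not the tinygrad API):

import itertools
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@dataclass(frozen=True)
class MaskedConst:
  # hypothetical stand-in for a UOps.CONST whose ShapeTracker src carries a mask
  val: float
  shape: Tuple[int, ...]
  mask: Optional[Tuple[Tuple[int, int], ...]] = None  # per-axis [lo, hi) valid ranges

def lower_const(c: MaskedConst) -> Dict[Tuple[int, ...], float]:
  # mirrors the lowerer hunk: valid.where(const_like(val), const_like(0))
  def valid(idx: Tuple[int, ...]) -> bool:
    return c.mask is None or all(lo <= i < hi for i, (lo, hi) in zip(idx, c.mask))
  return {idx: (c.val if valid(idx) else 0.0) for idx in itertools.product(*map(range, c.shape))}

# usage: a scalar 3.0 expanded to 4x4 but valid only in the first row
print(lower_const(MaskedConst(val=3.0, shape=(4, 4), mask=((0, 1), (0, 4)))))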