diff --git a/extra/datasets/sops.gz b/extra/datasets/sops.gz index 181fc25b5c..b476386852 100644 Binary files a/extra/datasets/sops.gz and b/extra/datasets/sops.gz differ diff --git a/test/helpers.py b/test/helpers.py index 847b060320..a66c5f2104 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype from tinygrad.engine.realize import Runner from tinygrad.dtype import ConstType, DType from tinygrad.nn.state import get_parameters -from tinygrad.helpers import T +from tinygrad.helpers import T, unwrap from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.rewriter import full_graph_rewrite from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator @@ -43,7 +43,9 @@ def rand_for_dtype(dt:DType, size:int): def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp: if st_src is None: st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),) - return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0)) + st = unwrap(st_src[0].st) + if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),)) + return UOp.const(dtype, val).valid(st) def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]: st = time.perf_counter_ns() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 55b4463a58..a65559b2e0 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -121,6 +121,8 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(4,).realize() helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1*-1], opts=[]) + # shapeless CONST in AST is not supported + @unittest.expectedFailure def test_const_alu_indexing_one_const_fine(self): st = ShapeTracker.from_shape((4,)).to_uop() load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 53c6ca5e5d..3dffc05c63 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1206,18 +1206,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.RECIP, dtypes.half, arg=None, src=( UOp(Ops.ADD, dtypes.half, arg=None, src=( - UOp(Ops.WHERE, dtypes.half, arg=None, src=( - x6:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.CONST, dtypes.half, arg=1.0, src=()), - x9:=UOp(Ops.CONST, dtypes.half, arg=0.0, src=()),)), + UOp(Ops.CONST, dtypes.half, arg=1.0, src=( + x6:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.EXP2, dtypes.half, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( - UOp(Ops.WHERE, dtypes.half, arg=None, src=( - x6, - UOp(Ops.CONST, dtypes.half, arg=2.0, src=()), - x9,)), + UOp(Ops.CONST, dtypes.half, arg=2.0, src=( + x6,)), UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( @@ -1232,10 +1227,8 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(Ops.WHERE, dtypes.half, arg=None, src=( - x6, - UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()), - x9,)),)),)),)),)),)),)) + UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=( + x6,)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @@ -1283,17 +1276,14 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.WHERE, dtypes.int, arg=None, src=( UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(Ops.CONST, dtypes.int, arg=1, src=()), - x20:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),)), - UOp(Ops.WHERE, dtypes.int, arg=None, src=( - x22:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.CONST, dtypes.int, arg=-1, src=()), - x20,)),)),)), - UOp(Ops.WHERE, dtypes.bool, arg=None, src=( - x22, - UOp(Ops.CONST, dtypes.bool, arg=True, src=()), - UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) + UOp(Ops.CONST, dtypes.int, arg=1, src=( + x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 50000), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.CONST, dtypes.int, arg=0, src=( + x20,)),)),)), + UOp(Ops.CONST, dtypes.int, arg=-1, src=( + x23:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.CONST, dtypes.bool, arg=True, src=( + x23,)),)),)),)),)),)),)) opts = [Opt(op=OptOps.GROUPTOP, axis=1, arg=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["AMD", "GPU", "METAL", "NV", "CUDA"]) @@ -1348,11 +1338,8 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( - x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()), - x10,)), + x7:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=( + x8:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( @@ -1364,20 +1351,18 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x8, - UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), - x10,)),)),)), + x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.CONST, dtypes.float, arg=-1.0, src=( + x8,)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - x22,)),)), + x20,)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - x22,)),)), + x20,)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - x22,)),)), + x20,)),)), x7,)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), @@ -1394,11 +1379,8 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( - x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()), - x10,)), + x7:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=( + x8:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( @@ -1410,20 +1392,18 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x8, - UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), - x10,)),)),)), + x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.CONST, dtypes.float, arg=-1.0, src=( + x8,)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - x22,)),)), + x20,)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - x22,)),)), + x20,)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - x22,)),)), + x20,)),)), x7,)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), diff --git a/test/test_schedule.py b/test/test_schedule.py index 1b4246f470..157a3d7285 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2026,8 +2026,7 @@ class TestSwizzle(unittest.TestCase): self.assertEqual(swizzle_cnt(ret), 1) def store_val(si:ScheduleItem): return si.ast.src[0].src[2] -# TODO: we only need valid on ast consts if it's masked, can fold this early to UOp.const -zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar())) +zero_pm = UPat(Ops.CONST, arg=0) class TestView(unittest.TestCase): def test_all_masked_out(self): # start with non CONST Ops @@ -2193,7 +2192,7 @@ class TestConst(unittest.TestCase): a = Tensor.ones((4,)).contiguous() sched = a.schedule() print(sched[0].ast) - const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),)) + const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.CONST)),)) self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1) run_schedule(sched) self.assertListEqual(a.tolist(), [1, 1, 1, 1]) diff --git a/test/test_uops.py b/test/test_uops.py index f050c72dd9..3dbb70ace5 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -524,7 +524,6 @@ class TestShapeSpec(unittest.TestCase): a = Tensor.ones((4, 4)).lazydata self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4))) - @unittest.expectedFailure def test_padded_const(self): a = Tensor.ones((1, 1)).pad(((1, 1), (1, 1))) ast = a.contiguous().schedule()[0].ast diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index 255c2fab33..bf23847e4e 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -78,7 +78,7 @@ class TestVerifyAST(unittest.TestCase): verify_ast(ast:=a.schedule()[-1].ast) store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0] self.assertEqual(store_st, ShapeTracker.from_shape((4, 4))) - const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0] + const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0] self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) def test_assert_swizzle(self): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 7ce68968c0..a20455492e 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -582,7 +582,10 @@ class Kernel: ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) if op.op in GroupOp.Buffer and op in self.bufs: st_uop = self.sts[self.bufs.index(op)].to_uop() - return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:])) + # NOTE: if CONST got masked after applying opts, we create a new VALID + if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st)) + # otherwise we just replace the VIEW source + return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:])) if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)) if op.op is Ops.REDUCE_AXIS: reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2 diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e4e025176b..c51b3606e3 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -194,12 +194,12 @@ to_si = PatternMatcher([ # don't need contiguous or assign anymore (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x), + # don't need DEVICE anymore + (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())), # PRELOAD becomes LOAD (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), - # CONST(VIEW) becomes VALID too, TODO: doesn't have to - (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 2e76b8b109..bfe532fee6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -162,7 +162,7 @@ class GroupOp: Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE} - Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID} + Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR} Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} # BinaryOps that can be flipped @@ -305,7 +305,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def full_shape(self) -> tuple[sint, ...]: if self.op is Ops.VIEW: return self.shape # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops - return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}])) + return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \ + # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this + and not (x.op is Ops.CONST and x.st is None)])) @property def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape @property @@ -385,7 +387,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype)) def valid(self, st:ShapeTracker): assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}" - return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self, 0) + from tinygrad.shape.shapetracker import ShapeTracker + # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked + unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop() + return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,))) @staticmethod def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx) def _reduce_op(self, op:Ops, axis:tuple[int, ...]): @@ -1330,7 +1335,7 @@ merge_views = PatternMatcher([ # push VIEW to parents view_left = merge_views+PatternMatcher([ # VIEW(CONST) becomes VALID - (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)), # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 13f4fa0915..437ca31d1a 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -42,7 +42,10 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]: graph: dict[int, tuple[str, list[int], str]] = {} excluded: set[UOp] = set() for u in (toposort:=x.toposort): - if u.op in {Ops.CONST, Ops.DEVICE}: excluded.update((u,) + u.src) + # always exclude DEVICE/CONST + if u.op in {Ops.DEVICE, Ops.CONST}: excluded.add(u) + # only exclude CONST VIEW source if it has no other children + if u.op is Ops.CONST and len(u.src) != 0 and all((cr:=c()) is None or cr.op is Ops.CONST for c in u.src[0].children): excluded.update(u.src) for u in toposort: if u in excluded: continue argst = str(u.arg)