diff --git a/.github/actions/setup-tinygrad/action.yml b/.github/actions/setup-tinygrad/action.yml
index 76323bc415..0b2dbc05a5 100644
--- a/.github/actions/setup-tinygrad/action.yml
+++ b/.github/actions/setup-tinygrad/action.yml
@@ -302,4 +302,4 @@ runs:
   - name: Install mesa (macOS)
     if: inputs.mesa == 'true' && runner.os == 'macOS'
     shell: bash
-    run: brew install sirhcm/tinymesa/tinymesa
+    run: brew install sirhcm/tinymesa/tinymesa_cpu
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index c563f79c62..39893f4388 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -626,11 +626,11 @@ jobs:
     - name: benchmark openpilot 0.9.9 dmonitoring
       run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
     - name: openpilot compile3 0.10.1 driving_vision
-      run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
+      run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
     - name: openpilot compile3 0.10.1 driving_policy
-      run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
+      run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
     - name: openpilot compile3 0.10.1 dmonitoring
-      run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
+      run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
     - name: benchmark MobileNetV2 on DSP
       run: |
        # generate quantized weights
diff --git a/autogen_stubs.sh b/autogen_stubs.sh
index 5d02cd37f4..58d919d597 100755
--- a/autogen_stubs.sh
+++ b/autogen_stubs.sh
@@ -520,17 +520,17 @@ generate_mesa() {
   LVP_NIR_OPTIONS=$(./extra/mesa/lvp_nir_options.sh $MESA_SRC)
   fixup $BASE/mesa.py
-  patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "f'{brew_prefix()}/lib/libtinymesa_cpu.dylib'"
+  patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "brew_path('tinymesa_cpu')" "brew_path('tinymesa')"
   echo "lvp_nir_options = gzip.decompress(base64.b64decode('$LVP_NIR_OPTIONS'))" >> $BASE/mesa.py
   cat <> $BASE/mesa.py
+  echo "def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
   sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/mesa.py # bitfield bug in clang2py
   sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/mesa.py
diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py
index 02b8496b26..1c831aa48d 100644
--- a/examples/openpilot/compile3.py
+++ b/examples/openpilot/compile3.py
@@ -121,6 +121,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file, tol):
   print("test vs onnx passed")
   return timings
 
+def bench(run, inputs):
+  from extra.bench_log import WallTimeEvent, BenchEvent
+  for _ in range(10):
+    with WallTimeEvent(BenchEvent.STEP):
+      run(**inputs).numpy()
+
 if __name__ == "__main__":
   onnx_file = fetch(OPENPILOT_MODEL)
   inputs, outputs = compile(onnx_file)
@@ -131,3 +137,5 @@ if __name__ == "__main__":
 
   if not getenv("FLOAT16"):
     test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
+  if getenv("BENCHMARK_LOG", ""):
+    bench(pickle_loaded, inputs)
diff --git a/examples/sdv2.py b/examples/sdv2.py
index 29b1abb8fd..856cf239ad 100644
--- a/examples/sdv2.py
+++ b/examples/sdv2.py
@@ -99,6 +99,7 @@ if __name__ == "__main__":
   parser.add_argument('--timing', action='store_true', help="Print timing per step")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
   parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
+  parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
   args = parser.parse_args()
 
   N = 1
@@ -112,19 +113,22 @@ if __name__ == "__main__":
 
   model = StableDiffusionV2(**params)
 
-  default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
-  weights_fn = args.weights_fn
-  if not weights_fn:
-    weights_url = args.weights_url if args.weights_url else default_weights_url
-    weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
-  with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
-    load_state_dict(model, safe_load(weights_fn), strict=False)
+  if not args.fakeweights:
+    default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
+    weights_fn = args.weights_fn
+    if not weights_fn:
+      weights_url = args.weights_url if args.weights_url else default_weights_url
+      weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
+
+    load_state_dict(model, safe_load(weights_fn), strict=False)
 
   if args.fp16:
     for k,v in get_state_dict(model).items():
       if k.startswith("model"):
-        v.replace(v.cast(dtypes.float16).realize())
+        v.replace(v.cast(dtypes.float16))
+
+  Tensor.realize(*get_state_dict(model).values())
 
   c = { "crossattn": model.cond_stage_model(args.prompt) }
   uc = { "crossattn": model.cond_stage_model("") }
diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 644c524476..4650b7e1d9 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -263,14 +263,16 @@ if __name__ == "__main__":
   parser.add_argument('--timing', action='store_true', help="Print timing per step")
   parser.add_argument('--seed', type=int, help="Set the random latent seed")
   parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
+  parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
   args = parser.parse_args()
 
   model = StableDiffusion()
 
   # load in weights
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
-    model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-    load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+    if not args.fakeweights:
+      model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
+      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
 
   if args.fp16:
     for k,v in get_state_dict(model).items():
diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py
index 3dce947828..1d0b223506 100644
--- a/test/external/external_benchmark_schedule.py
+++ b/test/external/external_benchmark_schedule.py
@@ -2,7 +2,8 @@ from extra.models.resnet import ResNet50
 from tinygrad import Tensor, nn, Device
 from tinygrad.helpers import Profiling, Timing, getenv
 from tinygrad.uop.ops import Ops
-from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
+from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
+from tinygrad.codegen.late.control_flow import linearize
 from tinygrad.uop.spec import type_verify
 
 if __name__ == "__main__":
@@ -39,7 +40,7 @@ if __name__ == "__main__":
     with Timing("***** model linearize in "):
       uops_line = []
       for u in rewritten_uops:
-        uops_line.append(apply_rewrites(u, rewrites_for_linearizer))
+        uops_line.append(linearize(u))
 
     with Timing("***** model verify in "):
-      for u in uops_line: type_verify(u.arg.lst)
-    print(sum(len(u.arg.lst) for u in uops_line))
+      for u in uops_line: type_verify(u)
+    print(sum(len(u) for u in uops_line))
diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py
index 8b04d4af2b..ec669781ba 100644
--- a/test/external/speed_v_theoretical.py
+++ b/test/external/speed_v_theoretical.py
@@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase):
   # theoretical is nv_tflops=165, amd_tflops=123
   def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=65)
-  def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=60)
+  def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=115, amd_tflops=60)
 
   # theoretical is nv_gbs=1008, amd_gbs=960
   def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
-  def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750)
+  def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=820, amd_gbs=750)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 7af6294c83..9a505a3921 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
   def _test_no_nested_ranges(self, lins, skip=None):
     for l in lins:
       range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
-      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
       for i,u in enumerate(ranges):
         if skip and i in skip: continue
         assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
     # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
     uops = get_program(ast, opts=opt).uops
     begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
-    end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
+    end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
     for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
     for u in uops:
       if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
@@ -214,8 +214,8 @@ class TestLinearizer(unittest.TestCase):
       else:
         assert u.src[1].op in GroupOp.ALU
         assert begin_range < uops.index(u) < end_range
-      # children of STORE are placed after ENDRANGE
-      if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src):
+      # children of END are placed after ENDRANGE
+      if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
         assert end_range < uops.index(u)
 
   def test_grouped_dims(self):
@@ -400,7 +400,7 @@ class TestLinearizer(unittest.TestCase):
     # # check the children's vins
     # TODO: src ALU are not the same, should it?
     # assert barrier.src == tuple(local_stores)
-    assert len([u for u in uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1
+    assert len([u for u in uops if u.op is Ops.IF and u.src[1] == barrier]) == 1
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
diff --git a/test/test_ops.py b/test/test_ops.py
index fb3869a295..022131c50d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2602,6 +2602,7 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
       lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
 
+  @unittest.skipIf(Device.DEFAULT == "AMD" and CI, "remu failure?")
   def test_avg_pool3d_failure(self):
     with Context(NOOPT=0):
       helper_test_op([(1,1,16,16,16)],
diff --git a/test/test_rangeify.py b/test/test_rangeify.py
index d0a4eea1c1..9bed5c1481 100644
--- a/test/test_rangeify.py
+++ b/test/test_rangeify.py
@@ -1,7 +1,9 @@
 import unittest
-from tinygrad import Tensor, nn
-from tinygrad.helpers import Context, GlobalCounters, CI, CPU_LVP, getenv
+from tinygrad import Tensor, nn, Device
+from tinygrad.helpers import Context, GlobalCounters, CI, getenv
 from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
+from tinygrad.renderer.ptx import PTXRenderer
+from tinygrad.renderer.nir import NIRRenderer
 
 class TestRangeifyAssign(unittest.TestCase):
   def test_assign_permuted(self):
@@ -40,7 +42,7 @@ elif getenv("BIG") > 0:
 else:
   BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
 
-@unittest.skipIf(CPU_LVP, "broken in LVP")
+@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
 class TestPcontig(unittest.TestCase):
   def test_flash_attention_bw(self):
     def fa_bw():
@@ -62,7 +64,7 @@ class TestPcontig(unittest.TestCase):
       Tensor.realize(*ret)
       return ret
 
-    with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2):
+    with Context(PCONTIG=2, DEBUG=2):
       grads = fa_bw()
     print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
diff --git a/test/test_tensor_uop.py b/test/test_tensor_uop.py
index 0a526ef5a1..21dfe41b57 100644
--- a/test/test_tensor_uop.py
+++ b/test/test_tensor_uop.py
@@ -3,7 +3,7 @@ import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.engine.realize import run_schedule
-from tinygrad.uop.ops import Ops, UOp, UPat
+from tinygrad.uop.ops import UOp
 from tinygrad.helpers import SPLIT_REDUCEOP
 
 class TestTensorUOp(unittest.TestCase):
@@ -93,7 +93,6 @@ class TestTensorUOp(unittest.TestCase):
     out.realize()
     self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())
 
-reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, allow_any_len=True, src=(UPat(), UPat((Ops.REDUCE_AXIS, Ops.REDUCE))))))
 @unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP")
 class TestReduceOp(unittest.TestCase):
   def test_no_split_reduce_kernel(self):
@@ -101,23 +100,18 @@ class TestReduceOp(unittest.TestCase):
     a = a.sum()
     sched = a.schedule()
     assert len(sched) == 1
-    assert reduce_kernel.match(sched[0].ast, {})
 
   def test_split_reduce_kernel_dim0(self):
     a = Tensor.rand(256, 255).realize()
     a = a.sum()
     sched = a.schedule()
     assert len(sched) == 2
-    for s in sched:
-      assert reduce_kernel.match(s.ast, {})
 
   def test_split_reduce_kernel_dim1(self):
     a = Tensor.rand(255, 256).realize()
     a = a.sum()
     sched = a.schedule()
     assert len(sched) == 2
-    for s in sched:
-      assert reduce_kernel.match(s.ast, {})
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index d5d75462f3..a38ef37af9 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -665,19 +665,6 @@ class TestUOpGraph(unittest.TestCase):
     bad_gate = UOp.const(dtypes.int, 1)
     with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
 
-  def test_switched_range_order(self):
-    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    cf = UOp.const(dtypes.float, 0.0)
-    r1 = UOp.range(2, 0)
-    r2 = UOp.range(2, 1)
-    alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
-    store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
-    uops = to_uops_list([store])
-    ranges = [x for x in uops if x.op is Ops.RANGE]
-    endranges = [x for x in uops if x.op is Ops.ENDRANGE]
-    # ranges are closed in the right order
-    self.assertEqual(endranges[-1].src[0], ranges[0])
-
 @track_rewrites()
 def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
 
@@ -845,8 +832,6 @@ class TestIFUOps(unittest.TestCase):
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
-    for st in sink.src:
-      self.assertEqual(len(st.src), 2)
 
   def test_expand_ifs_one_gate(self):
     gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
@@ -863,8 +848,6 @@ class TestIFUOps(unittest.TestCase):
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
-    for st in sink.src:
-      self.assertEqual(len(st.src), 2)
 
   # this will be fixed with the merge gated stores bounty
   @unittest.expectedFailure
@@ -879,8 +862,6 @@ class TestIFUOps(unittest.TestCase):
     if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
-    for st in sink.src:
-      self.assertEqual(len(st.src), 2)
 
 class TestUOpTags(unittest.TestCase):
   def test_inc_by_one(self):
diff --git a/test/unit/test_block_reorder.py b/test/unit/test_block_reorder.py
deleted file mode 100644
index e81b1ff6c5..0000000000
--- a/test/unit/test_block_reorder.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import unittest, random
-from tinygrad.dtype import dtypes
-from tinygrad.uop.ops import print_uops, UOp, Ops
-from tinygrad.codegen.late.linearize import block_reorder
-from tinygrad.renderer.cstyle import OpenCLRenderer
-
-def is_toposorted(lst:list[UOp]):
-  seen = set()
-  for u in lst:
-    if any(p not in seen for p in u.src): return False
-    seen.add(u)
-  return True
-
-class TestBlockReorder(unittest.TestCase):
-  def _test_randomize(self, golden:list[UOp]):
-    # test random order is always same
-    for _ in range(50):
-      # shuffle and form a valid toposort
-      lst = golden[:]
-      random.shuffle(lst)
-      topolst = []
-      for u in lst:
-        for p in u.toposort():
-          if p not in topolst: topolst.append(p)
-      assert is_toposorted(topolst)
-
-      for x,y in zip(golden, this_order:=block_reorder(topolst)):
-        if x is not y:
-          print_uops(golden)
-          print_uops(this_order)
-        self.assertIs(x, y)
-
-  def _test_render(self, golden:list[UOp]):
-    return OpenCLRenderer().render(golden)
-
-  def test_loads(self):
-    a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0)
-    b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1)
-    c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2)
-    v1 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx0")
-    v2 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx1")
-    v1 = v1*27
-    v2 = v2*4
-    loads = [
-      a.index(v1).load(dtype=dtypes.float),
-      a.index(v1+1).load(dtype=dtypes.float),
-      a.index(v1+2).load(dtype=dtypes.float),
-      a.index(v1+3).load(dtype=dtypes.float),
-      b.index(v2).load(dtype=dtypes.float),
-      b.index(v2+1).load(dtype=dtypes.float),
-      b.index(v2+2).load(dtype=dtypes.float),
-      b.index(v2+3).load(dtype=dtypes.float)]
-    #random.shuffle(loads)
-    sink = c.store(sum(loads)).sink()
-
-    # determine golden order
-    golden = block_reorder(list(sink.toposort()))
-
-    # render for test
-    print(self._test_render(golden))
-    #print_uops(golden)
-
-    # assert the loads are in this order
-    self.assertListEqual([g.src[0].src[1].render() for g in golden if g.op is Ops.LOAD],
-      ['(gidx1*4)', '((gidx1*4)+1)', '((gidx1*4)+2)', '((gidx1*4)+3)',
-       '(gidx0*27)', '((gidx0*27)+1)', '((gidx0*27)+2)', '((gidx0*27)+3)'])
-
-    # assert math is after loads
-    first_math = [i for i,g in enumerate(golden) if g.op is Ops.ADD and g.dtype == dtypes.float][0]
-    assert not any(x.op is Ops.LOAD for x in golden[first_math:])
-
-    # confirm the sort is stable
-    self._test_randomize(golden)
-
-if __name__ == '__main__':
-  unittest.main()
diff --git a/test/unit/test_kernelize.py b/test/unit/test_kernelize.py
index e571c1d297..3cc0b0c0cc 100644
--- a/test/unit/test_kernelize.py
+++ b/test/unit/test_kernelize.py
@@ -20,8 +20,8 @@ class TestKernelize(unittest.TestCase):
     self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
     self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
     # input Tensor and user contiguous kernelize
-    self.assertIs(a0.uop.base.op, Ops.ASSIGN)
-    self.assertIs(a.uop.base.op, Ops.ASSIGN)
+    self.assertIs(a0.uop.base.op, Ops.AFTER)
+    self.assertIs(a.uop.base.op, Ops.AFTER)
 
   def test_two_reduce_w_add(self):
     a = Tensor.ones(16,16).contiguous()
@@ -31,7 +31,7 @@ class TestKernelize(unittest.TestCase):
     # NOTE: the +1 is fused with a1, so a1 is not kernelized
     self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
     # the input to the REDUCE_AXIS is an ASSIGN though
-    self.assertIs(a1.uop.base.src[0].base.op, Ops.ASSIGN)
+    self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py
index 7f3790c217..619d10e5ca 100644
--- a/test/unit/test_simplify_valid_idx.py
+++ b/test/unit/test_simplify_valid_idx.py
@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes
 from tinygrad.uop.ops import UOp, Ops
 from tinygrad.uop.symbolic import simplify_valid
 from tinygrad.helpers import Context
-from .test_uop_symbolic import check_uop_against_string
+from test.unit.test_uop_symbolic import check_uop_against_string
 
 def get_gated_load_uop(valid:UOp, idx:UOp):
   return UOp(Ops.LOAD, dtypes.float, (
diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py
index 9810de38d4..4edee1323b 100644
--- a/test/unit/test_viz.py
+++ b/test/unit/test_viz.py
@@ -148,6 +148,14 @@ class TestViz(BaseTestViz):
     a2 = uop_to_json(a)[id(a)]
     self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
 
+  def test_colored_label_multiline(self):
+    arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
+    src = [Tensor.empty(1).uop for _ in range(10)]
+    a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg)
+    exec_rewrite(a, [PatternMatcher([])])
+    a2 = next(get_viz_details(0, 0))["graph"][id(a)]
+    self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
+
   def test_inf_loop(self):
     a = UOp.variable('a', 0, 10, dtype=dtypes.int)
     b = a.replace(op=Ops.CONST)
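Context for the codegen diffs that follow: the deleted block-based linearizer and its replacement, control_flow.linearize, share one core scheme — give every uop a priority, number all uops in an "ideal" total order, then emit the closest valid topological sort using a min-heap keyed on that numbering. A minimal standalone sketch of that scheme on a toy dependency graph; the node names and priority values here are invented for illustration and are not tinygrad APIs:

import heapq
from collections import defaultdict

# toy dependency graph: each node maps to the nodes it depends on
deps = {"load_a": [], "load_b": [], "add": ["load_a", "load_b"], "store": ["add"]}
priority = {"load_a": -1000, "load_b": -1000, "add": 0, "store": 0}  # loads float to the front

children, in_degree = defaultdict(list), {n: len(s) for n, s in deps.items()}
for n, srcs in deps.items():
    for s in srcs: children[s].append(n)

# number nodes in "ideal" order: priority first, then a stable tiebreak
nkey = {n: i for i, n in enumerate(sorted(deps, key=lambda n: (priority[n], n)))}

# heap-constrained toposort: always emit the ready node closest to the ideal order
heap = [(nkey[n], n) for n, d in in_degree.items() if d == 0]
heapq.heapify(heap)
order = []
while heap:
    order.append(n := heapq.heappop(heap)[1])
    for c in children[n]:
        in_degree[c] -= 1
        if in_degree[c] == 0: heapq.heappush(heap, (nkey[c], c))

print(order)  # ['load_a', 'load_b', 'add', 'store']

Keying the heap on a precomputed total order is what makes the result stable: ready nodes come out in the same relative order no matter how the input list was shuffled, which is the property the deleted test_block_reorder.py checked.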
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 155ed805c2..bd3d771e2a 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -14,10 +14,10 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
 from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
   ReduceContext, correct_load_store, pm_render
-from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
 from tinygrad.codegen.opt.postrange import pm_postrange_opt
 from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
 from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
+from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize
 
 @dataclass
 class RewriteStep:
@@ -30,12 +30,6 @@ class RewriteStep:
 
 def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
 
-rewrites_for_linearizer = [
-  RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
-  RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
-  RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
-  RewriteStep(pm_finalize, name="Linearizer: Finalize")]
-
 def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
   # cache with the values of the context vars
   return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
@@ -101,11 +95,15 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
     pm_final_rewrite = pm_decomp+pm_render+extra_matcher
     ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
 
-  # return the list (with optional linearizer)
-  return ret + (rewrites_for_linearizer if linearizer else [])
+  # this was the linearizer
+  ret.append(RewriteStep(pm_merge_ends, name="merge ends"))
+  ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True))
 
-def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
-  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
+  # return the list
+  return ret
+
+def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
+  return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
 
 def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
   """
@@ -119,6 +117,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
     Linear program in UOps.
   """
-  lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst)
+  lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None))
   if __debug__: type_verify(lst)
   return lst
diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index 5f406f78b0..15a82d2df9 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -1,4 +1,4 @@
-import math
+import math, functools, operator
 from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
 from tinygrad.helpers import all_int, dedup, get_contraction
 from tinygrad.dtype import dtypes
@@ -87,7 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp):
     except ValueError: continue
   return s.substitute(subs)
 
+def add_barrier_and_if(buf:UOp, e:UOp):
+  # TODO: this is not generic
+  local_ranges = [x for x in e.ended_ranges if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE]
+  if len(local_ranges) == 0: return None
+  return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), e.barrier())))
+
 pm_add_gpudims = PatternMatcher([
   # add gpudims must be last
   (UPat(Ops.SINK, name="s"), add_gpudims),
+  # add barrier and if
+  (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.END, name="e"))), add_barrier_and_if),
 ])
diff --git a/tinygrad/codegen/late/control_flow.py b/tinygrad/codegen/late/control_flow.py
new file mode 100644
index 0000000000..3eb9e56931
--- /dev/null
+++ b/tinygrad/codegen/late/control_flow.py
@@ -0,0 +1,102 @@
+import heapq
+from collections import defaultdict
+from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
+
+def linearize(u:UOp) -> list[UOp]:
+  lst = list(u.toposort())
+  in_this_block = set(lst)
+  local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
+  in_degree:dict[UOp, int] = {}
+  priorities:dict[UOp, int] = {}
+
+  # get local children and assign priorities
+  # NOTE: this requires the lst be locally toposorted
+  for u in reversed(lst):
+    in_degree[u] = 0
+    for s in u.src:
+      if s in in_this_block:
+        local_children[s].append(u)
+        in_degree[u] += 1
+    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
+    priority = [0] + [priorities[x] for x in local_children[u]]
+    if u.op is Ops.LOAD: priority.append(-1000)
+    if u.op is Ops.BARRIER: priority.append(-1500)
+    # ranges are scheduled as late as possible so anything that can be outside is
+    #if u.op is Ops.RANGE: priority = [2000]
+    # move defines and consts to the top
+    if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
+    priorities[u] = min(priority)
+
+  # number the uops in "ideal" order
+  nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
+
+  # then force then to be toposorted in as close to the ideal order as possible
+  heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
+  newlst = []
+  while heap:
+    newlst.append(u:=heapq.heappop(heap)[1])
+    for v in local_children[u]:
+      in_degree[v] -= 1
+      if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
+
+  assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
+  return newlst
+
+class CFGContext:
+  def __init__(self, sink:UOp):
+    # there are 3 relationships between ranges:
+    # nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
+    # dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
+    # independent, endrange y is not a dependency of endrange x
+    # everything is nested inside the sink
+    deps: dict[UOp, set[UOp]] = {}
+    nesting: dict[UOp, UOp] = {}
+    for u in sink.toposort():
+      deps[u] = set().union(*(deps[s] for s in u.src))
+      if u.op in (Ops.END, Ops.ENDIF, Ops.SINK):
+        nesting |= {x:u for x in deps[u] if x.op in (Ops.END, Ops.ENDIF) and (u.op is Ops.SINK or u.src[0] in deps[x]) and x not in nesting}
+      if u.op in (Ops.RANGE, Ops.END, Ops.IF, Ops.ENDIF): deps[u] |= {u}
+
+    self.edges: dict[UOp, UOp] = {}
+    siblings: dict[UOp, list[UOp]] = {}
+    for k,vv in nesting.items(): siblings.setdefault(vv, []).append(k)
+    for k,v in siblings.items():
+      # range/if that have dependencies on other siblings need to run after them
+      order = sorted(v, key=lambda x: len(deps[x].intersection(v)))
+      zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[0]] + order, order)
+      for x,y in zipped:
+        # TODO: is this check correct?
+        if y.src[0] not in x.backward_slice_with_self:
+          self.edges[y.src[0]] = x
+
+pm_add_control_flow = PatternMatcher([
+  (UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
+])
+
+def do_merge_ends(s:UOp):
+  # NOTE: this can fail
+  stacked: dict[UOp, list[UOp]] = {}
+  dangling_ifs = []
+  for x in s.toposort():
+    if x.op in {Ops.END, Ops.ENDIF}:
+      assert x.op is not Ops.END or x.arg == 1, "ends must be single ends for linearizer"
+      stacked.setdefault(x.src[0], []).append(x)
+    if x.op is Ops.IF: dangling_ifs.append(x)
+  dangling_ifs = [x for x in dangling_ifs if x not in stacked]
+  replaces = {}
+  for k,v in stacked.items():
+    if len(v) == 1: continue
+    rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=x[0].arg)
+    for x in v: replaces[x] = rep
+  if not len(replaces) and not len(dangling_ifs): return None
+  ret = s.substitute(replaces)
+  if len(dangling_ifs):
+    assert len(dangling_ifs) == 1, "we only support 1 dangling if"
+    ret = ret.replace(src=(UOp(Ops.ENDIF, src=(dangling_ifs[0], *ret.src)),))
+  return ret
+
+pm_merge_ends = PatternMatcher([
+  # for renderering and linearizing, all ends must end one loop
+  (UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
+  (UPat(Ops.SINK, name="s"), do_merge_ends),
+])
\ No newline at end of file
diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index c0012b73ff..7ada724c99 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -123,7 +123,7 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
   return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
 
 load_store_folding = PatternMatcher([
-  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
+  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines).or_after(name="buf")), UPat.var("vec"))), expand_index),
   # GEP after LOAD
   (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
    lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
@@ -242,11 +242,13 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
   return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
 
 devectorize = PatternMatcher([
+  # CAST after AFTER
+  (UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
   # no ALU on vectorized dtypes
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
   (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
-  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
 ])
 
 pm_render = PatternMatcher([
@@ -266,9 +268,9 @@ pm_render = PatternMatcher([
     UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
   (UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),), allow_any_len=True, name="l").or_casted()),
    lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
-  # gate any stores that aren't gated with ifs
+  # gate any stores that aren't gated with if/endif pairs
   (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
-   lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
+   lambda store,idx: UOp(Ops.ENDIF, src=(uif:=UOp(Ops.IF, src=(idx.src[2],)), UOp(Ops.STORE, src=store.src[:2]+(uif,)+store.src[2:]))) if \
     len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
 ])
@@ -293,15 +295,17 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   # if we have a range
   if len(reduce_range) != 0:
     topo = inp.toposort()
-    stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
-    input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
+    ended_ranges = flatten([x.src[:x.arg] for x in topo if x.op is Ops.END])
+    input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
     identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
-    acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
-    do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
-    lst = [acc.load(do_store, *reduce_range)] + lst  # put acc as the first element
+    acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
+    acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
+      acc.index(UOp.const(dtypes.int, 0)).store(identity)
+    lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst  # put acc as the first element
     ctx.acc_num += 1
   ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
-  return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
+  if len(reduce_range) == 0: return ret
+  return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(ends=reduce_range[::-1])).index(UOp.const(dtypes.int, 0)).load()
 
 pm_reduce = PatternMatcher([
   # REDUCE -> DEFINE_ACC+ASSIGN
diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py
index 9a42d414ce..1f270394e6 100644
--- a/tinygrad/codegen/late/expander.py
+++ b/tinygrad/codegen/late/expander.py
@@ -87,7 +87,7 @@ expander = PatternMatcher([
     lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
-        Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
+        Ops.VECTORIZE, Ops.IF, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
   (UPat(Ops.CONTRACT, name="con"), do_contract),
   # BARRIERs aren't actually expanded
   (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
@@ -145,8 +145,7 @@ def fix_group_for_reduce(x:UOp):
   reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
   buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
-  # gate with an if on the store + do the final reduce
-  buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
+  # do the final reduce (if/barrier are added in gpudims step)
   return buf.reduce(*reduce_loop, arg=x.arg)
 
 pm_pre_expander = PatternMatcher([
diff --git a/tinygrad/codegen/late/linearize.py b/tinygrad/codegen/late/linearize.py
deleted file mode 100644
index d860125adf..0000000000
--- a/tinygrad/codegen/late/linearize.py
+++ /dev/null
@@ -1,243 +0,0 @@
-from __future__ import annotations
-import heapq
-from collections import defaultdict
-from dataclasses import dataclass, replace
-from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
-from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER
-
-# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
-def block_reorder(lst:list[UOp]) -> list[UOp]:
-  in_this_block = set(lst)
-  local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
-  in_degree:dict[UOp, int] = {}
-  priorities:dict[UOp, int] = {}
-
-  # get local children and assign priorities
-  # NOTE: this requires the lst be locally toposorted
-  for u in reversed(lst):
-    in_degree[u] = 0
-    for s in u.src:
-      if s in in_this_block:
-        local_children[s].append(u)
-        in_degree[u] += 1
-    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
-    priority = [0] + [priorities[x] for x in local_children[u]]
-    if u.op is Ops.LOAD: priority.append(-1000)
-    if u.op is Ops.BARRIER: priority.append(-1500)
-    priorities[u] = min(priority)
-
-  # number the uops in "ideal" order
-  nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
-
-  # then force then to be toposorted in as close to the ideal order as possible
-  heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
-  newlst = []
-  while heap:
-    newlst.append(u:=heapq.heappop(heap)[1])
-    for v in local_children[u]:
-      in_degree[v] -= 1
-      if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
-
-  assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
-  return newlst
-
-# ***** basic block *****
-
-def disp(y:UOp) -> str:
-  if y.op is Ops.IF: return f'IF{id(y)}'
-  if y.op is Ops.RANGE: return str(y.arg)
-  return ""
-
-@dataclass(frozen=True, eq=False)
-class BasicBlock:
-  lst: tuple[UOp, ...]
-  ctx: tuple[UOp, ...] = ()
-  end: UOp|None = None
-  cnt: int = 0
-  child_ctx: tuple[UOp, ...]|None = None
-  def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
-  def __repr__(self):
-    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
-           f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
-           f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
-  def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
-
-def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))
-
-# ***** block context *****
-
-@dataclass
-class BlockContext:
-  child_count: dict[UOp, int]
-  block_ctxs: dict[UOp, tuple[UOp, ...]]
-  child_ctxs: dict[UOp, tuple[UOp, ...]]
-  def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
-  @staticmethod
-  def from_sink(sink:UOp) -> BlockContext:
-    # get children and all block contexts
-    ctx = BlockContext({}, {}, {})
-    for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
-      this_block_ctx: list[UOp] = []
-      ctx.child_count[u] = 0
-
-      # get children and accumulate the last_ctx
-      for s in u.src:
-        if s.op is Ops.SPECIAL: continue
-        # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
-        ctx.child_count[s] += 1
-        this_block_ctx += ctx.last_ctx(s)
-
-      # save the block ctx. SINK never has anything
-      ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
-
-      # RANGE/IF add to the next ctx
-      # STORE/ASSIGN subtract from the next ctx
-      if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
-      elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
-    return ctx
-
-# ***** make blocks *****
-
-DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
-
-def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
-  ends_to_add = [z for z in new_ctx if z not in current_ctx]
-  while len(ends_to_add):
-    r:UOp = ends_to_add.pop(-1)
-    new_ctx = tuple([z for z in new_ctx if z is not r])
-    end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
-    base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
-  return base_block
-
-def make_block_bottom_up(ctx:BlockContext, x:UOp):
-  if x.op is Ops.BLOCKSTART:
-    current_ctx, child_ctx = x.arg
-    lst = list(x.src)
-    child_count = 1
-  else:
-    current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
-    lst = [x]
-
-  # count of times we've seen this block, or a seed for a new block if we can't merge it
-  unmergable: defaultdict[UOp, int] = defaultdict(int)
-  blockseeds = defaultdict(list)
-
-  # add the srcs of this to the frontier
-  # NOTE: things may be in here multiple times, that's okay
-  frontier_nodes = list(flatten(y.src[::-1] for y in lst))
-  while len(frontier_nodes):
-    u = frontier_nodes.pop(0)
-    if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
-      # count is correct
-      if (newctx:=ctx.block_ctxs[u]) == current_ctx:
-        # block has same context, merge it, and put the srcs on the frontier
-        lst.append(u)
-        frontier_nodes.extend(u.src[::-1])
-      else:
-        # block has different context, add it to blockseeds
-        blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
-      del unmergable[u]
-    else:
-      # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
-      unmergable[u] += 1
-
-  # add unmergables to sources
-  srcs = []
-  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt
-
-  # add blockseeds, with blockends as needed
-  for (new_ctx, new_child_ctx), v in blockseeds.items():
-    base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
-    srcs.append(add_blockends(base_block, new_ctx, current_ctx))
-
-  lst = lst[::-1]
-  if BLOCK_REORDER: lst = block_reorder(lst)
-  bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
-  return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
-
-# we prevent the source of the SPECIAL from being linearized since its not part of the kernel
-def raise_bottom_up_gate(): raise BottomUpGate()
-
-block_create = PatternMatcher([
-  (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
-  (UPat(Ops.SPECIAL), raise_bottom_up_gate)
-])
-
-# ***** blockend merging ****
-
-def merge_blockends(sink:UOp) -> UOp|None:
-  # only run on the final BLOCK with the SINK in it
-  if sink.arg.lst[-1].op is not Ops.SINK: return None
-  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
-  blockends_to_arg: dict[UOp, list[UOp]] = {}
-  for be in sink.toposort():
-    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
-  new_forks = {}
-  for k,v in blockends_to_arg.items():
-    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
-    if len(v) > 1:
-      bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
-      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
-      # NOTE: bb.ctx != u.arg.ctx can cause problems here
-      for u in v: new_forks[u] = out
-  if len(new_forks) == 0: return None
-  return sink.substitute(new_forks)
-
-pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
-
-# ***** block merging ****
-
-def merge_block(x:UOp):
-  unmergable_blocks, mergable_blocks = [], []
-  mergable_dict: defaultdict[UOp, int] = defaultdict(int)
-  for y in x.src:
-    if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
-    elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
-    else: unmergable_blocks.append(y)
-  for k,v in mergable_dict.items():
-    if v == k.arg.cnt: mergable_blocks.append(k)
-    else: unmergable_blocks.extend([k]*v)
-  if len(mergable_blocks) == 0: return None
-  del mergable_dict
-
-  # create the block
-  arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
-  return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
-
-def remove_blockend(x:UOp):
-  # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
-  if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
-
-  if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
-    assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
-    parent_block = parent_blocks[0]
-    assert len(parent_blocks) == parent_block.arg.cnt
-    # NOTE: DEFINE_ACC doesn't have to be handled in any special way
-    late_ops = list(x.arg.lst)
-    # NOTE: we have to add a barrier at the start if barrier is used in the range
-    if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
-      late_ops = [UOp(Ops.BARRIER)] + late_ops
-    # peephole opt, remove any BARRIERs next to each other
-    for i in range(len(late_ops)-1):
-      if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
-    arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
-    return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
-  # else the whole context ended by the blockend is already in this block and we can safely turn it into a block
-  return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))
-
-block_merge = PatternMatcher([
-  (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
-  (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
-])
-
-# ****** finalize ******
-
-def finalize(sink:UOp) -> UOp:
-  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
-    raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
-
-  # place the early things
-  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
-  return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
-
-pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index 55a443dfdf..720237b3b0 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -4,8 +4,8 @@ from collections import defaultdict
 from typing import cast, Final
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
 from tinygrad.device import Buffer
-from tinygrad.dtype import AddrSpace, dtypes, ImageDType
-from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
+from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
 from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
@@ -64,21 +64,8 @@ class Scheduler:
     return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
 
   def _globalizable_rngs(self) -> list[UOp]:
-    store_rngs = self.ast.src[0].src[2:]
-
-    # filter any not in local stores
-    local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
-      or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
-    for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
-
-    # filter any not in reduces
-    # TODO: enable this
-    """
-    reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE]
-    for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
-    """
-
-    return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []
+    # all ranges that end before any STOREs
+    return [x for x in self.ast.toposort(lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in self.ast.ranges]
 
   def convert_loop_to_global(self):
     if not self.opts.has_local: return None
@@ -89,11 +76,11 @@ class Scheduler:
     self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
 
   def colors(self) -> list[str]:
-    store_rngs = flatten([x.src[2:] for x in self.ast.src])
+    globalizible_rngs = self._globalizable_rngs()
     ret = []
     for x,r in zip(self.axis_types, self.rngs):
       if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
-      elif r not in store_rngs and x == AxisType.LOOP: ret.append("BLACK")
+      elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("BLACK")
       else: ret.append(axis_colors[x])
     return ret
   def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
@@ -348,7 +335,7 @@ def apply_opts(ctx:Renderer, ast:UOp):
   elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
     from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
     # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
-    if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD):
+    if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice):
       k = hand_coded_optimizations(k)
   return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py
index 21cce836f3..bb87c103b9 100644
--- a/tinygrad/codegen/opt/search.py
+++ b/tinygrad/codegen/opt/search.py
@@ -156,7 +156,9 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
         if lib in seen_libs: continue
         # filter out kernels that use 1000x more compute than the smallest
         least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
-        if least_compute_ops*1000 < this_compute_ops: continue
+        if least_compute_ops*1000 < this_compute_ops:
+          if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}")
+          continue
         seen_libs.add(lib)
         try:
           tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py
index 9053df5eaa..eeb84071eb 100644
--- a/tinygrad/codegen/simplify.py
+++ b/tinygrad/codegen/simplify.py
@@ -13,14 +13,16 @@ def flatten_range(r:UOp):
 
 pm_flatten_range = PatternMatcher([
   # real ranges only
   (UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range),
+  # END is only on RANGES. TODO: this is copied from symbolic
+  (UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
 ])
 
 def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
 
 def simplify_merge_adjacent(u:UOp) -> UOp|None:
   reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE]
-  i = range_start[u.op]
-  while i < len(u.src)-1:
-    r0, r1 = u.src[i], u.src[i+1]
+  i = 0
+  while i < len(u.ended_ranges)-1:
+    r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1]
     # check same type
     if r0.arg[-1] == r1.arg[-1]:
       # check if the ranges to merge are in the same reduces
@@ -39,7 +41,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
   return u
 
 pm_simplify_ranges = PatternMatcher([
-  (UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
+  (UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
 ])
 
 def mark_range_mod(ctx, r:UOp, c:UOp):
@@ -57,7 +59,7 @@ def do_substitute(ctx, x: UOp):
 
 def dont_sub_ranges_for_image(ctx, x:UOp):
   if isinstance(x.src[0].dtype, ImageDType):
-    for s in x.src[1:]: ctx[s] = None
+    for s in x.src[0].ranges: ctx[s] = None
 
 pm_split_ranges = PatternMatcher([
   (UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod),
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 7db5310bf8..2e4b1f2520 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
 import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
 from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
-from tinygrad.helpers import unwrap_class_type
+from tinygrad.helpers import unwrap_class_type, suppress_finalizing
 from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer
 
@@ -163,6 +163,7 @@ class Buffer:
     return self._trace_num
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
+  @suppress_finalizing
   def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
   def __repr__(self): return f"
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
@@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
   in_degree: dict[UOp, int] = {}
   var_vals: dict[str, int] = {}
   for u in sched_sink.toposort():
-    if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
+    if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
     k = u.src[1]
     in_degree.setdefault(k, 0)
     for s in k.src:
-      if s.op is Ops.ASSIGN:
+      if s.op is Ops.AFTER:
         children[s.src[1]].append(k)
         in_degree[k] += 1
       elif s.op in {Ops.MSELECT, Ops.MSTACK}:
         for ss in s.src:
           if ss.op is Ops.MSELECT: ss = ss.src[0]
           if ss.op is not Ops.BUFFER:
-            assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
+            assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
             children[ss.src[1]].append(k)
             in_degree[k] += 1
       elif s.op is Ops.BUFFER:
@@ -43,7 +43,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
         assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
         var_vals[var.expr] = val
       else:
-        raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
+        raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
 
   # linearize KERNEL UOps into ScheduleItems in BFS order
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 9cacc65b4b..2a8739f3b6 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -170,9 +170,9 @@ SPEC = ContextVar("SPEC", 0)
 # TODO: disable by default due to speed
 IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
 PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
-REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)
 DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
+
 @dataclass(frozen=True)
 class Metadata:
   name: str
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index b27ab036c0..c8884146d3 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -36,7 +36,7 @@ class BatchNorm:
     self.weight: Tensor|None = Tensor.ones(sz) if affine else None
     self.bias: Tensor|None = Tensor.zeros(sz) if affine else None
 
-    self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
 
   def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index 87ddce695a..b70b51c012 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -28,9 +28,12 @@ class Estimates:
     mult_stack: list[sint] = []
     dont_count: set[UOp] = set()
     if ignore_indexing:
+      def range_gate(x): return x.op is not Ops.RANGE
       for u in uops:
         if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
-          dont_count = dont_count.union(u.src[0].toposort())
+          # if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
+          dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
+          # TODO: is this correct? this all needs to be cleaned up
           if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
         elif u.op is Ops.IF:
           dont_count = dont_count.union(u.src[0].toposort())
@@ -45,7 +48,7 @@ class Estimates:
         mults *= cast(sint, u.src[0].ssimplify()) # SPECIAL are already counted in mults
         mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
-      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+      elif u.op is Ops.END: mults = mult_stack.pop(-1)
       elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
       elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): lds += u.dtype.itemsize * mults
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index c3a8e1508d..5afcd0711a 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -1,7 +1,7 @@
 from typing import Literal, Callable, cast
 import os, math, sys
 from collections import defaultdict, Counter
-from tinygrad.codegen.opt import tc
+from tinygrad.codegen.opt import tc, axis_letters
 from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
@@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu
 
 base_rewrite = PatternMatcher([
   (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
   (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
-  (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
+  (UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
   (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
   # r method accesses
   (UPat(Ops.RANGE, name="x"),
@@ -144,6 +144,9 @@ class CStyleLanguage(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.NOOP: continue
+      if u.op is Ops.AFTER:
+        r[u] = r[u.src[0]]
+        continue
       if u.op is Ops.SINK:
         if u.arg is not None: name = u.arg.function_name
         continue
@@ -160,7 +163,7 @@ class CStyleLanguage(Renderer):
       # naming
       prefix = None
       if u.op is Ops.SPECIAL: r[u] = u.arg
-      elif u.op is Ops.RANGE: r[u] = "ridx"+range_str(u)
+      elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
       else:
         prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
                   Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
@@ -170,7 +173,7 @@ class CStyleLanguage(Renderer):
       l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
-      if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
+      if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
       if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
          (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
          (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 8be73d536f..b67bd9cb32 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([
     f"  br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
     f"  br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
    f"  {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
-  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+  (UPat(Ops.END, name="x"), lambda ctx,x:
    f"  br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
    f"  {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
    f"  {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
@@ -167,6 +167,9 @@ class LLVMRenderer(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.NOOP: continue
+      if u.op is Ops.AFTER:
+        r[u] = r[u.src[0]]
+        continue
       if u.op is Ops.SINK:
         if u.arg is not None: name = u.arg.function_name
         continue
diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py
index 26cf519c7d..eec9cade89 100644
--- a/tinygrad/renderer/nir.py
+++ b/tinygrad/renderer/nir.py
@@ -1,4 +1,4 @@
-from typing import Callable, cast
+from typing import Callable, cast, Any
 from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
 from tinygrad.helpers import DEBUG, OSX, unwrap
 from tinygrad.renderer import Renderer
@@ -169,10 +169,13 @@ class NIRRenderer(Renderer):
   def render(self, uops:list[UOp]):
     self.prerender(uops)
     for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
-    self.r, self.param_idx, ranges = {}, 0, []
+    self.r: dict[UOp, Any] = {}
+    self.param_idx, ranges = 0, []
     for u in uops:
       if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
+      elif u.op is Ops.AFTER:
+        self.r[u] = self.r[u.src[0]]
       elif u.op == Ops.SINK:
         if u.arg is not None: self.b.shader.contents.info.name = mesa.char_pointer_cast(u.arg.function_name)
       elif u.op == Ops.DEFINE_LOCAL:
@@ -183,7 +186,7 @@ class NIRRenderer(Renderer):
         nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
         mesa.nir_push_loop(self.b)
         self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
-      elif u.op == Ops.ENDRANGE:
+      elif u.op == Ops.END:
         nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
             functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
         mesa.nir_pop_loop(self.b, None)
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index 4695c880c3..565faf52b4 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([
     if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
   (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
   (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
-  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
+  (UPat(Ops.END, name="x", src=(UPat.var("src0"),), allow_any_len=True), lambda ctx, x, src0: [
     ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
     ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
@@ -179,6 +179,9 @@ class PTXRenderer(Renderer):
     name = "test"
     for u in uops:
      if u.op is Ops.NOOP: continue
+      if u.op is Ops.AFTER:
+        self.r[u] = self.r[u.src[0]]
+        continue
       if u.op is Ops.SINK:
         if u.arg is not None: name = u.arg.function_name
         continue
@@ -216,7 +219,7 @@ class PTXRenderer(Renderer):
           [ssa("wmma_in", 
dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)], [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] - prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), + prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]), Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None)) if prefix: r[u] = ssa(prefix, u, dtype) diff --git a/tinygrad/runtime/autogen/mesa.py b/tinygrad/runtime/autogen/mesa.py index 78a0efc2e6..66cd9e5342 100644 --- a/tinygrad/runtime/autogen/mesa.py +++ b/tinygrad/runtime/autogen/mesa.py @@ -7,13 +7,14 @@ # LONGDOUBLE_SIZE is: 16 # import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers -def brew_prefix(): - try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip() - except Exception: return '' +def brew_path(nm): + try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib" + except Exception: return 'failed' PATHS_TO_TRY = [ (BASE:=os.getenv('MESA_PATH', f"/usr{'/local/' if helpers.OSX else '/'}lib"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so'), f'{BASE}/libtinymesa{EXT}', - f'{brew_prefix()}/lib/libtinymesa_cpu.dylib', + brew_path('tinymesa_cpu'), + brew_path('tinymesa'), ] def _try_dlopen_tinymesa_cpu(): library = ctypes.util.find_library("tinymesa_cpu") @@ -6087,7 +6088,7 @@ struct_nir_op_info._fields_ = [ nir_op_info = struct_nir_op_info try: nir_op_infos = (struct_nir_op_info * 489).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_op_infos') -except AttributeError: pass +except (AttributeError, ValueError): pass try: nir_op_is_selection = _libraries['FIXME_STUB'].nir_op_is_selection nir_op_is_selection.restype = ctypes.c_bool @@ -8118,7 +8119,7 @@ c__EA_nir_intrinsic_index_flag = ctypes.c_uint32 # enum nir_intrinsic_index_flag = c__EA_nir_intrinsic_index_flag nir_intrinsic_index_flag__enumvalues = c__EA_nir_intrinsic_index_flag__enumvalues try: nir_intrinsic_index_names = (ctypes.POINTER(ctypes.c_char) * 75).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_index_names') -except AttributeError: pass +except (AttributeError, ValueError): pass class struct_nir_intrinsic_instr(Structure): pass @@ -8242,7 +8243,7 @@ struct_nir_intrinsic_info._fields_ = [ nir_intrinsic_info = struct_nir_intrinsic_info try: nir_intrinsic_infos = (struct_nir_intrinsic_info * 732).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_infos') -except AttributeError: pass +except (AttributeError, ValueError): pass try: nir_intrinsic_src_components = _libraries['libtinymesa_cpu.so'].nir_intrinsic_src_components nir_intrinsic_src_components.restype = ctypes.c_uint32 @@ -19877,4 +19878,4 @@ __all__ = \ 'union_util_format_description_0', 'util_format_colorspace', 'util_format_layout', 'va_list'] lvp_nir_options = gzip.decompress(base64.b64decode('H4sIAAAAAAAAA2NgZGRkYGAAkYxgCsQFsxigwgwQBoxmhCqFq2WEKwIrAEGIkQxoAEMALwCqVsCiGUwLMHA0QPn29nBJkswHANb8YpH4AAAA')) -def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). 
See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)') +def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)') diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 9dd145d299..9a8ade8e18 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -52,11 +52,11 @@ class PythonProgram: loop_ends: dict[int, int] = {} while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] - void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} + void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) - if uop is Ops.ENDRANGE: + if uop is Ops.END: loop_ends[idp[0]] = i i = idp[0] continue @@ -72,7 +72,8 @@ class PythonProgram: if g: _store(m, o+j, v, dtp[1].scalar()) i += 1 continue - if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: + if uop is Ops.AFTER: ul[i] = inp[0] + elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: assert isinstance(dtype, PtrDType), dtype storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported") diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index f9ec8d1062..04c2987180 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -58,7 +58,7 @@ class LLVMCompiler(Compiler): self.diag_msgs.append(msg) self.handle_diag = handle_diag llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None) - super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") + super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo) diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py index 8f83c34657..7e8ff6150a 100644 --- a/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad/runtime/support/compiler_cuda.py @@ -69,7 +69,9 @@ class PTXCompiler(Compiler): def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) class NVPTXCompiler(PTXCompiler): - def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx") + def __init__(self, arch:str): + nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint()))) + super().__init__(arch, cache_key="nv_ptx") def compile(self, src:str) -> bytes: jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle) jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "".encode()), handle) diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index f5a7a999db..fb438489bb 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -52,11 +52,11 @@ class IndexingContext: def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): if x.op in {Ops.BUFFERIZE, Ops.INDEX, 
Ops.KERNEL}: return None - if x.op is Ops.ASSIGN and x.src[1].op is Ops.KERNEL: return None + if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None new_srcs = [] for s in x.src: new_src = s - if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): + if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: realized_ranges = ctx.realize_map[s] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 495922758c..1039e55385 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,10 +2,9 @@ from typing import cast from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo -from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType +from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate from tinygrad.uop.symbolic import symbolic_flat -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE -from tinygrad.helpers import DEBUG_RANGEIFY +from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op @@ -137,6 +136,9 @@ def cleanup_dead_axes(b:UOp): # move the tag to the expand. 
NOTE: this expand tag might not survive return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag) +def gate_substitute(ctx, b:UOp) -> None: + if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate() +pm_gate_substitute = PatternMatcher([(UPat(GroupOp.All, name="b"), gate_substitute)], compiled=False) # if a buffer is being stored just for permutes or something, remove it # we want to reexpress the indexes of idx2 in terms of the implied b1 def remove_bufferize(src:UOp, buf:UOp, idx:UOp): @@ -179,11 +181,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): # if it makes it here, the bufferize is removed # this is the ranges replaced # NOTE: if buf src is a const, we don't replace it - if REAL_SUBSTITUTE: - return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}) - else: - replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]) - return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2])))) + return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute) def pre_bufferize(b:UOp, x:UOp, copy:UOp): nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) @@ -243,7 +241,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp): bufs: set[UOp] = set() def gate_input(u:UOp): # TODO: add cache to fix n^2 - if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) + if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) return not is_load root.toposort(gate=gate_input) @@ -278,7 +276,8 @@ def bufferize_to_store(x:UOp): assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" # in assign, this is the buffer size, not the bufferize size # TODO: assign_mops here - ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag) + do_store = assign_target.replace(dtype=sdtype).store(assign_src).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE]) + ret = assign_target.src[0].after(do_store) mops = [] walk = assign_mops while walk is not assign_mops.base: @@ -290,8 +289,8 @@ def bufferize_to_store(x:UOp): # NOTE: the DEFINE_LOCAL needs to be disambiguated here if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg.device, size, x.dtype) - ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag) - ret = ret.forced_reshape(shape) + do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE]) + ret = buf.after(do_store).forced_reshape(shape) # TODO: is this right? what if it's offset if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs): sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs]) @@ -302,9 +301,8 @@ def bufferize_to_store(x:UOp): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - # store has the other dtype here - # TODO: how is this unified? 
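
The shape all three bufferize_to_store branches converge on is STORE → END → AFTER: the store is wrapped in an END that closes its RANGEs, and the buffer is re-exposed through an AFTER so later readers order behind the write. A minimal sketch using only constructors that appear in this patch; the RANGE/CONST arguments and the size 16 are illustrative assumptions, not taken from the diff:

from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, AxisType

buf = UOp.new_buffer("CPU", 16, dtypes.float)        # destination buffer (hypothetical device/size)
rng = UOp(Ops.RANGE, dtypes.index, (UOp.const(dtypes.index, 16),), arg=(0, AxisType.LOOP))  # assumed RANGE construction
val = UOp.const(dtypes.float, 1.0)                   # stand-in for the value computed inside the loop
do_store = buf.index(rng).store(val).end(ends=[rng]) # END(src=(rng, STORE), arg=1) closes the loop after the write
ret = buf.after(do_store)                            # AFTER passes buf through; consumers of ret run after the STORE
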
- return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape) + do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).end(ends=[x for x in rngs if x.op is Ops.RANGE]) + return buf.after(do_store).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), @@ -336,12 +334,13 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp): ctx.vars[b] = None return b.src[0] -def handle_assign(ctx:LocalAddBufferContext, assign:UOp): - buf = assign.as_buf() +def handle_after(ctx:LocalAddBufferContext, after:UOp): + if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None + buf = after.as_buf() # HACK to put the buffer in the MAP instead of MSTACK/MSELECT if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0] assert buf not in ctx.map - ctx.map[buf] = assign + ctx.map[buf] = after return buf def renumber_range(ctx:LocalAddBufferContext, r:UOp): @@ -351,7 +350,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp): return ret def find_bufs(x:UOp): - idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.ASSIGN) if s.op is Ops.INDEX] + idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX] read_from: dict[UOp, Ops] = {} if any((buf:=idx.as_buf()).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs): raise RuntimeError(f"cycle detected while indexing {buf}") @@ -360,7 +359,7 @@ to_define_global = PatternMatcher([ (UPat(Ops.STORE, name="x"), find_bufs), (UPat(Ops.BUFFER, name="buf"), debuf), (UPat(Ops.BIND, name="b"), unbind_kernel), - (UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign), + (UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after), # HACK in case any CONSTs were replaced # this is only needed if you are using symbolic @@ -389,16 +388,9 @@ rangeify_codegen = PatternMatcher([ # add loads to non ptr indexes # TODO: this can be moved into codegen? 
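# (sketch of the graph shapes the unified pattern below matches, for non-pointer INDEX dtypes:
#    INDEX(DEFINE_GLOBAL g, idx)                   -> INDEX(...).load()
#    INDEX(AFTER(DEFINE_LOCAL l, store, ...), idx) -> INDEX(...).load()
#  the LOAD-on-STORE and IF(BARRIER) special cases are dropped because
#  read-after-write ordering now rides on the AFTER edge instead)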
- (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), - - # TODO: this can be moved into codegen - (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), - lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), - - # TODO: hack for group for reduce - (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), - lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), + (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) + .f(Ops.INDEX, name="idx", allow_any_len=True), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), ]) def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): @@ -419,9 +411,8 @@ class Kernel: ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op) return f"" -def split_store(ctx:list[UOp], x:UOp): +def split_store(ctx:list[UOp], x:UOp) -> UOp|None: if len(x.ranges): return None - if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None # local kernel rewrite lctx = LocalAddBufferContext() @@ -431,16 +422,22 @@ def split_store(ctx:list[UOp], x:UOp): metadatas = [ctx[y].metadata for y in lctx.parent_tags] # NOTE: the hack for COPY is here - ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \ - if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1] + for u in ret.toposort(): + # TODO: this can be wrong if there's multiple of these + if u.op in {Ops.COPY, Ops.BUFFER_VIEW}: + ret = u + break + else: + ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) + kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") - return x.as_buf().assign(kernel) + return kernel split_kernels = PatternMatcher([ - (UPat(Ops.STORE, name="x"), split_store), + (UPat((Ops.STORE, Ops.END), name="x"), split_store), ]) def tag_uop(ctx:list[UOp], x:UOp): @@ -471,25 +468,6 @@ replace_contiguous = PatternMatcher([ (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) -def do_sub_recurse(s:UOp): - x,keys,values = s.src[0], s.src[1].src, s.src[2].src - # SUBSTITUTE applied to SUBSTITUTE runs the child SUB on the parents. though this is probably wrong in the generic case - if x.op is Ops.SUBSTITUTE: - sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:]) - sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:]) - return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v)) - # here we actually do the SUBSTITUTE - if x in keys: return values[keys.index(x)] - # we filter any keys where the ranges don't overlap. 
this keeps the algorithm O(output graph size) - x_ranges = x.ranges - new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)} - # if there's no SUBSTITUTEs left, we can just return x - if len(new_kv) == 0: return x - # then we add SUBSTITUTE to all parents - uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values())) - return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src])) -pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)]) - @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True) def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph") @@ -504,8 +482,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers") - # TODO: can you substitute and remove costly buffers at the same time? - tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph @@ -524,13 +500,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: kernel_assign: dict[UOp, UOp] = {} assign_rep: dict[UOp, UOp] = {} for u in tsink.toposort(): - if u.op is not Ops.ASSIGN: continue + if u.op is not Ops.AFTER: continue kernel_assign[u.buf_uop] = u for s in u.src[1].src: # TODO: this is probably broken for MSELECT/MSTACK if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue - if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()): - raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER") + if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()): + raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign") diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1028519341..74edfc145f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -249,9 +249,9 @@ class Tensor(MathTrait): self.kernelize(*lst) sink = UOp.sink(*[x.uop for x in (self,)+lst]) - # remove all ASSIGNs, after scheduling, the tensors are just buffers - remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN} - _apply_map_to_tensors(remove_assign_map, name="Remove Assigns") + # remove all AFTERs, after scheduling, the tensors are just buffers + remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.AFTER} + _apply_map_to_tensors(remove_assign_map, name="Remove After") # create the schedule schedule, var_vals = create_schedule_with_vars(sink) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 2922fd4471..1c20da1185 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -12,19 +12,18 @@ class 
Ops(FastEnum): NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702 SENTINEL = auto() + # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] + AFTER = auto() + # buffer ops COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702 # create buffer BUFFERIZE = auto() - SUBSTITUTE = auto() # ops that adjust the behavior of the scheduler CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702 - # blocks in linearizer (only used there) - BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702 - # movement ops! these only exist in the tensor graph RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702 MULTI = auto() # MULTI is really a movement op @@ -67,7 +66,7 @@ class Ops(FastEnum): WHERE = auto(); MULACC = auto() # noqa: E702 # control flow ops - BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702 + BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702 # consts. VCONST is a vectorized const VCONST = auto(); CONST = auto() # noqa: E702 @@ -91,7 +90,6 @@ class GroupOp: Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP} Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR} - Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART} # BinaryOps that can be flipped Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR} diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 55ee4a69c1..39dc77c6b7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def ptrdtype(self) -> PtrDType: - if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType") + if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}") return self.dtype # *** uop shape stuff *** @@ -179,7 +179,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST: + Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST: return None # some ops init the shape @@ -190,7 +190,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) # passthrough ops - case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE: return self.src[0]._shape + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER | Ops.END: + return self.src[0]._shape # ops with custom handling case Ops.KERNEL: return self.arg.ast._shape @@ -275,6 +276,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): for s in self.src[:range_start[self.op]]: ret.update(s.ranges) for s in UOp.sink(*self.src[range_start[self.op]:]).ranges: if s in ret: del ret[s] + elif self.op is Ops.END: + for s in self.src[self.arg:]: ret.update(s.ranges) + for s in UOp.sink(*self.src[:self.arg]).ranges: + if s in ret: del ret[s] 
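# (END bookkeeping, mirroring UOp.end below: END.src = (*ends, body, *deps) with arg=len(ends),
#  so src[:arg] are the RANGEs being closed and only src[arg:] contribute still-open ranges)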
else: for s in self.src: ret.update(s.ranges) return ret @@ -284,6 +289,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.RANGE: return {self:None} return self._ranges + @functools.cached_property + def ended_ranges(self): + match self.op: + case Ops.REDUCE: return self.src[1:] + case Ops.END: return self.src[:self.arg] + case _: raise RuntimeError(f"{self.op} doesn't end ranges") + # *** uop evaluation *** def simplify(self, tracked=False, full_symbolic=True): @@ -324,7 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, *srcs:UOp|None, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) - def __getitem__(self, idx): return self.index(idx) + def __getitem__(self, *idx): return self.index(*idx) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source return UOp.const(self.dtype, b, device=self._device, shape=self._shape) @@ -349,6 +361,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs) + def end(self, *src:UOp, ends:Sequence[UOp]): + if len(ends) == 0: return self + return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends)) + def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def alu(self, op, *src:UOp, **kwargs): @@ -525,6 +541,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def _device(self) -> str|tuple[str, ...]|None: if self.op is Ops.DEVICE: return self.arg if self.op is Ops.BUFFERIZE: return self.arg.device + if self.op is Ops.AFTER: return self.src[0]._device if self.op is Ops.MSELECT: assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device" return self.src[0].device[self.arg] @@ -538,8 +555,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.BUFFER: return self if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg) if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src)) - assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}" - return self.src[0].base + assert self.op is Ops.AFTER, f"must be AFTER {self.op}" + return self.src[0].buf_uop.base def as_buf(self) -> UOp: if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg) @@ -810,6 +827,8 @@ class UPat(MathTrait): @staticmethod def any(*src): return UPatAny(src=src) def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,))) + def or_after(self, name:str|None=None): + return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True)) @staticmethod @functools.cache @@ -836,7 +855,6 @@ class UPat(MathTrait): def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs) def fuse(self): return self.alu(Ops.FUSE) def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs) - def or_broadcasted(self, **kwargs): 
return UPat.any(self, self.broadcast(**kwargs)) def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs) def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) @@ -1091,7 +1109,7 @@ class RewriteContext: new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) except BottomUpGate: # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs - self.replace[n] = new_n + self.replace[n] = unwrap(test_n) continue stack.append((n, 1, new_n)) for x in reversed(new_n.src): @@ -1171,6 +1189,7 @@ pm_lower_index_dtype = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)), (UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"), lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))), + # TODO: this is only triggering if they are all casts, correct? (UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))), ]) def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 78be792353..9d1e9794d3 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -66,8 +66,8 @@ buffer_spec = PatternMatcher([ ]) assign_spec = PatternMatcher([ - # KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER - (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), + # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER + (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), # ASSIGN has a target and a value. 
It can also optionally depend on other assigns (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), @@ -111,6 +111,9 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ # REDUCE with an outerworld range (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), + + # AFTER if things were kernelized + (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True) ]) # ***** uop type spec ***** @@ -156,12 +159,16 @@ spec = PatternMatcher([ (UPat(Ops.DEFINE_REG, src=()), lambda: True), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), - (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ - all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), + (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x: + rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ + all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)), (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), + # allow AFTER on buffers + (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True), + # **** new style load/store **** # make sure all index dtypes have been lowered @@ -171,18 +178,14 @@ spec = PatternMatcher([ # INDEX is used in new style load/store # INDEX takes a - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True), - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True), - - # LOAD on STORE - (UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True), + (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True), + (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), # LOAD takes a (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])), (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index), - # STORE takes a - (UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store), + # STORE takes a (UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE @@ -193,7 +196,7 @@ spec = PatternMatcher([ (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)), - (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), + (UPat(Ops.END, dtype=dtypes.void), lambda: True), # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), @@ -201,9 +204,8 @@ spec = PatternMatcher([ (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a - (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), - (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True), - 
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True), + (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True), (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), @@ -234,9 +236,6 @@ full_spec = PatternMatcher([ # SENTINEL should never be in the graph (UPat(Ops.SENTINEL), lambda: False), - # allow any SUBSTITUTE - (UPat(Ops.SUBSTITUTE), lambda: True), - # Invalid must have type Index (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), # where on index in rhs position is fine @@ -269,7 +268,7 @@ full_spec = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True), # linearizer: outputs + intermediate KERNELs - (UPat((Ops.BLOCKSTART, Ops.BLOCK, Ops.BLOCKFINAL, Ops.BLOCKEND, Ops.KERNEL), dtype=dtypes.void), lambda: True), + (UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True), # allow index dtype on a restricted set of UOps (UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE, @@ -283,6 +282,8 @@ full_spec = PatternMatcher([ (UPat(Ops.DEFINE_VAR), lambda: True), # reshape on STORE (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True), + # allow any AFTER + (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True), ])+tensor_uop_spec+spec # ***** uop helpers ***** diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index e9fec2ae9e..85fa1a804b 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -377,6 +377,13 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y: x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None), ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), + # only RANGE/IF/STORE/KERNEL have side effects + (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ + tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))), + # after with 1 src is just src[0] + (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), + # END is only on RANGES + (UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))), ])+gep_pushing symbolic_flat = symbolic+PatternMatcher([ diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 34f68ce448..9a11acf95e 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -2,6 +2,7 @@ tinygrad viz + @@ -38,24 +39,37 @@ ::-webkit-scrollbar-thumb { background: #686977; } a { color: #4a90e2; + text-decoration: underline; + cursor: pointer; } ul { padding: 0; - opacity: 0.6; white-space: nowrap; cursor: pointer; } - ul.active { + ul > p { + opacity: 0.6; + } + ul.active > p { opacity: 1; } ul > ul { display: none; + margin-left: 6px; + } + ul.has-children > p::before { + content:"▸ "; + } + ul.has-children.expanded > p::before { + content:"▾ "; } ul.expanded > ul { display: block; } - ul.disabled { + ul.disabled > p { opacity: 0.4; + } + ul.disabled { pointer-events: none; } label { 
@@ -137,7 +151,7 @@ .metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * { margin-top: 12px; } - .ctx-list > ul > * + * { + ul > * + * { margin-top: 4px; } .graph { diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 90ad10b6be..42fc837349 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) { if (parents == null && children == null) return; const src = [...parents, ...children, d.id]; nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id)); - const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; + const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath"); d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port"); e.stopPropagation(); @@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) { }).selectAll("text").data(d => { const ret = [[]]; for (const { st, color } of parseColors(d.label, defaultColor="initial")) { - for (const [i, l] of st.split("\n").entries()) { - if (i > 0) ret.push([]); - ret.at(-1).push({ st:l, color }); - } + const lines = st.split("\n"); + ret.at(-1).push({ st:lines[0], color }); + for (let i=1; i d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan") @@ -174,10 +173,16 @@ function tabulate(rows) { return root; } -var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity; +var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity, shapeMetadata = new Map(); +function focusShape(shape) { + saveToHistory({ shape:focusedShape }); + focusedShape = shape?.key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel); + return document.querySelector(".metadata").replaceChildren(shapeMetadata.get(focusedShape) ?? ""); +} + async function renderProfiler() { displayGraph("profiler"); - d3.select(".metadata").node().replaceChildren(focusedShape?.html ?? ""); + d3.select(".metadata").node().replaceChildren(shapeMetadata.get(focusedShape) ?? ""); // layout once! if (data != null) return updateProgress({ start:false }); const profiler = d3.select(".profiler").html(""); @@ -243,15 +248,17 @@ async function renderProfiler() { } const html = document.createElement("div"); html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node()); + const argsDiv = document.createElement("div"); argsDiv.id = "args"; html.appendChild(document.createElement("br")); html.appendChild(argsDiv); if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info; if (shapeRef != null) { - const p = html.appendChild(document.createElement("p")); - p.innerText = "\nView Codegen Rewrite"; p.style.cursor = "pointer"; - p.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step); + const a = html.appendChild(document.createElement("a")); + a.innerText = "\nView codegen rewrite"; + a.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step); } // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; - const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? 
"\n"+e.info : ""), html, key, ...shapeRef }; + if (key != null) shapeMetadata.set(key, html); + const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, ...shapeRef }; if (e.key != null) shapeMap.set(e.key, arg); // offset y by depth shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor }); @@ -295,17 +302,30 @@ async function renderProfiler() { const rows = [["DType", dtype], ["Len", formatUnit(sz)], ["Size", formatUnit(nbytes, "B")], ["Lifetime", formatTime(dur)]]; if (users != null) rows.push(["Users", users.length]); const info = html.appendChild(tabulate(rows).node()); + const arg = {tooltipText:info.outerHTML, key:`${k}-${num}`}; for (let u=0; u focusShape(shape); + const args = shapeMetadata.get(shape.key).querySelector("#args"); + const bufArg = d3.create("p").text(`${bufInfo} ${rows[2][1]}`).style("cursor", "pointer").style("margin-top", "4px").on("click", () => { + const device = document.getElementById(k); + if (!isExpanded(device)) device.click(); + focusShape(arg); + }).node(); + bufArg.dataset.num = num; + let before = null; + for (const c of args.children) { if (+c.dataset.num > num) { before = c; break; } } + args.insertBefore(bufArg, before); } } - const arg = {tooltipText:info.outerHTML, html, key:`${k}-${num}`}; + shapeMetadata.set(arg.key, html) shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) }); } // generic polygon merger @@ -338,6 +358,7 @@ async function renderProfiler() { else if (tid === focusedDevice) { track.shapes = track.views[0]; offset += rescaleTrack(track, tid, 1/track.scaleFactor); } } data.axes.y = newFocus != null ? { domain:[0, (t=data.tracks.get(newFocus)).peak], range:[t.offsetY+t.height, t.offsetY], fmt:"B" } : null; + toggleCls(document.getElementById(focusedDevice), document.getElementById(newFocus), "expanded"); focusedDevice = newFocus; return resize(); }); @@ -396,7 +417,7 @@ async function renderProfiler() { lw += e.label[li].width; } } - if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push([p, pcolor]); } + if (focusedShape != null && e.arg?.key === focusedShape) { paths.push([p, pcolor]); } } } // draw axes @@ -463,15 +484,11 @@ async function renderProfiler() { } } - function focusShape(shape) { - focusedShape = shape; render(zoomLevel); - return document.querySelector(".metadata").replaceChildren(shape?.html ?? ""); - } canvas.addEventListener("click", e => { e.preventDefault(); const foundRect = findRectAtPosition(e.clientX, e.clientY); if (foundRect?.step != null && foundRect?.key == null) { return setCtxWithHistory(foundRect.ctx, foundRect.step); } - if (foundRect?.key != focusedShape?.key) { focusShape(foundRect); } + if (foundRect?.key != focusedShape) { focusShape(foundRect); } }); canvas.addEventListener("mousemove", e => { @@ -530,10 +547,10 @@ function codeBlock(st, language, { loc, wrap }={}) { return ret; } -function setActive(e) { - if (e == null) return; - e.classList.add("active"); - requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" })); +function toggleCls(prev, next, cls, value) { + prev?.classList.remove(cls); + next?.classList.toggle(cls, value ?? 
true);
+  requestAnimationFrame(() => next?.scrollIntoView({ behavior: "auto", block: "nearest" }));
 }

 // ** hljs extra definitions for UOps and float4
@@ -563,31 +580,41 @@ const evtSources = [];
 // context: collection of steps
 const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false};
 function setState(ns) {
-  const { currentCtx:prevCtx, currentStep:prevStep } = state;
+  const { ctx:prevCtx, step:prevStep } = select(state.currentCtx, state.currentStep);
   Object.assign(state, ns);
   // update element styles if needed
-  document.getElementById(`ctx-${state.currentCtx}`)?.classList.toggle("expanded", state.expandSteps);
-  if (state.currentCtx !== prevCtx) {
-    document.getElementById(`ctx-${prevCtx}`)?.classList.remove("active", "expanded");
-    setActive(document.getElementById(`ctx-${state.currentCtx}`));
-  }
-  if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) {
-    document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active");
-    setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`));
+  const { ctx, step } = select(state.currentCtx, state.currentStep);
+  toggleCls(prevCtx, ctx, "expanded", state.expandSteps);
+  if (ctx?.id !== prevCtx?.id) toggleCls(prevCtx, ctx, "active");
+  if (ctx?.id !== prevCtx?.id || step?.id !== prevStep?.id) {
+    toggleCls(prevStep, step, "active");
+    // walk the tree back until all parents expanded so that the child is visible
+    let e = step;
+    while (e?.parentElement?.id.startsWith("step")) {
+      e.parentElement.classList.add("expanded");
+      e = e.parentElement;
+    }
   }
   // re-render
   main();
 }

+const getSubrewrites = (ul) => ul.querySelectorAll(":scope > ul");
+
+function saveToHistory(ns) {
+  // NOTE: browser does a structured clone, passing a mutable object is safe.
+  history.replaceState(ns, "");
+  history.pushState(ns, "");
+}
+
 // set a new context and keep the old one in browser history
 function setCtxWithHistory(newCtx, step=0) {
-  // NOTE: browser does a structured clone, passing a mutable object is safe.
-  history.replaceState(state, "");
-  history.pushState(state, "");
+  saveToHistory(state);
   setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 });
 }

 window.addEventListener("popstate", (e) => {
+  if (e.state?.shape != null) return focusShape({ key:e.state?.shape });
   if (e.state != null) setState(e.state);
 });
@@ -605,15 +632,25 @@ async function main() {
     p.onclick = () => {
       setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
     }
+    const stack = []; let list = ul;
     for (const [j,u] of steps.entries()) {
-      const inner = ul.appendChild(document.createElement("ul"));
-      inner.id = `step-${i}-${j}`;
-      inner.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
-      inner.style.marginLeft = `${8*u.depth}px`;
-      inner.onclick = (e) => {
+      while (stack.length && stack.at(-1).depth >= u.depth) stack.pop();
+      const list = stack.length > 0 ? stack.at(-1).li : ul;
+      u.li = list.appendChild(document.createElement("ul"));
+      u.li.id = `step-${i}-${j}`;
+      const p = u.li.appendChild(document.createElement("p"));
+      p.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
+      p.onclick = (e) => {
         e.stopPropagation();
-        setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
+        const subrewrites = getSubrewrites(e.currentTarget.parentElement);
+        if (subrewrites.length) { e.currentTarget.parentElement.classList.toggle("expanded"); }
+        setState({ currentStep:j, currentCtx:i });
       }
+      stack.push(u);
+    }
+    for (const l of ul.querySelectorAll("ul > ul > p")) {
+      const subrewrites = getSubrewrites(l.parentElement);
+      if (subrewrites.length > 0) { l.innerText += ` (${subrewrites.length})`; l.parentElement.classList.add("has-children"); }
     }
   }
   return setState({ currentCtx:-1 });
@@ -706,8 +743,9 @@ async function main() {
   rewriteList.className = "rewrite-list";
   for (let s=0; s<=step.match_count; s++) {
     const ul = rewriteList.appendChild(document.createElement("ul"));
-    ul.innerText = s;
     ul.id = `rewrite-${s}`;
+    const p = ul.appendChild(document.createElement("p"));
+    p.innerText = s;
     ul.onclick = () => setState({ currentRewrite:s });
     ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : "";
     if (s > 0 && s === currentRewrite) {
@@ -762,22 +800,32 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid

 // **** keyboard shortcuts

+const select = (ctx, step) => ({ ctx:document.getElementById(`ctx-${ctx}`), step:document.getElementById(`step-${ctx}-${step}`) });
+const deselect = (element) => {
+  const parts = element?.id.split("-").map(Number);
+  return element?.id.startsWith("ctx") ? { ctx:parts[1], step:null } : element?.id.startsWith("step") ? {ctx:parts[1], step:parts[2]} : {};
+}
+const isExpanded = (el) => el?.classList.contains("expanded");
+
 document.addEventListener("keydown", (event) => {
   const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
   // up and down change the step or context from the list
   const changeStep = expandSteps && ctxs[currentCtx].steps?.length;
+  const { step, ctx } = select(currentCtx, currentStep);
   if (event.key == "ArrowUp") {
     event.preventDefault();
     if (changeStep) {
-      return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
+      let prev = deselect(step.previousElementSibling);
+      if (prev.step == null && isExpanded(step.parentElement)) prev = deselect(step.parentElement);
+      return prev.step != null && !isExpanded(step) && setState({ currentRewrite:0, currentStep:prev.step });
     }
     return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false });
   }
   if (event.key == "ArrowDown") {
     event.preventDefault();
     if (changeStep) {
-      const totalUOps = ctxs[currentCtx].steps.length-1;
-      return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) });
+      const next = deselect(isExpanded(step) ? step.children[1] : step.nextElementSibling);
+      return next.step != null && setState({ currentRewrite:0, currentStep:next.step });
     }
     return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false });
   }
@@ -787,6 +835,7 @@ document.addEventListener("keydown", (event) => {
   if (currentCtx === -1) {
     return setState({ currentCtx:0, expandSteps:true });
   }
+  if (expandSteps && getSubrewrites(step).length) return step.children[0].click();
   return setState({ expandSteps:!expandSteps });
 }
 // left and right go through rewrites in a single UOp
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index baa3656850..8ac090359e 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -17,10 +17,10 @@
 uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0",
                Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080",
                Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
                Ops.KERNEL: "#3e7f55",
-               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
-               Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
+               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
+               Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
                Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
-               Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00"}
+               Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}

 # VIZ API
@@ -33,8 +33,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
     steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
               "query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
     if isinstance(k.ret, ProgramSpec):
-      steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
-      steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
+      steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src", "depth":0})
+      steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm", "depth":0})
     for key in k.keys: ref_map[key] = i
     ret.append({"name":k.display_name, "steps":steps})
   return ret
@@ -71,7 +71,7 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
     if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
     label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
-    for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src):
+    for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
       if x in excluded:
         arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
         label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
@@ -82,6 +82,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
         label += f"\n{shape_to_str(u.shape)}"
       if u.op in {Ops.INDEX, Ops.BUFFERIZE}: label += f"\n{u.render()}"
+      if u.op is Ops.END:
+        label += "\n"+' '.join([f"{colored(u.src[i].arg[0], axis_colors[u.src[i].arg[-1]])}({u.src[i].vmax+1})" for i in range(u.arg)])
     except Exception: label += "\n"
     if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
@@ -98,9 +100,11 @@ def _reconstruct(a:int):
   return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)

 def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
-  ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"})
-  yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None,
-         "diff":None, "upat":None}
+  next_sink = _reconstruct(ctx.sink)
+  # in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
+  ignore_indexing = trace.keys[i].display_name.startswith("Schedule") and not (ctx.name in {"kernel split"} or \
+    any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,)))
+  yield {"graph":uop_to_json(next_sink, ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
   replaces: dict[UOp, UOp] = {}
   for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
     replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
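
A note on the step tree (illustration only, not part of the patch): the new stack loop in main() is what turns the flat, depth-annotated step list into nested <ul> elements. Below is a minimal standalone sketch of the same algorithm on plain objects, with a hypothetical nest() helper and made-up step names:

// A step becomes a child of the nearest preceding step with a strictly
// smaller depth, mirroring the `stack.at(-1).depth >= u.depth` pop loop above.
function nest(steps) {
  const root = { children: [] }, stack = [];
  for (const u of steps) {
    while (stack.length && stack.at(-1).depth >= u.depth) stack.pop();
    const parent = stack.length ? stack.at(-1) : root;
    (parent.children ??= []).push(u);  // attach under parent (create list lazily)
    stack.push(u);
  }
  return root.children;
}
// nest([{name:"a",depth:0},{name:"b",depth:1},{name:"c",depth:2},{name:"d",depth:1}])
// nests "b" and "d" under "a", and "c" under "b".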
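
Similarly for keyboard navigation (illustration only): select() maps (ctx, step) indices to DOM nodes through their ids, and deselect() recovers the indices from a node's id, so ArrowUp/ArrowDown can walk real siblings and children instead of doing index arithmetic. The `step-${i}-${j}` id scheme makes the round-trip cheap; a sketch with a hypothetical parse() standing in for deselect():

// "step-3-7".split("-") gives ["step","3","7"]; .map(Number) gives [NaN,3,7],
// so parts[1] is the ctx index and parts[2] is the step index.
const parse = (id) => {
  const parts = id.split("-").map(Number);
  return id.startsWith("step") ? { ctx: parts[1], step: parts[2] } : {};
};
console.log(parse("step-3-7"));  // { ctx: 3, step: 7 }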
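
Finally, the saveToHistory() helper pairs history.replaceState (stash the outgoing state in the current history entry) with history.pushState (mint a new entry), so pressing Back fires popstate with the saved snapshot. A browser-only sketch of the pattern; the console.log handler stands in for the patch's setState call:

// Snapshot the outgoing state, then push a fresh entry for the new one.
function saveToHistory(snapshot) {
  history.replaceState(snapshot, "");  // current entry now holds the snapshot
  history.pushState(snapshot, "");     // new entry becomes the active one
}
window.addEventListener("popstate", (e) => {
  if (e.state != null) console.log("restore", e.state);
});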