mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
Merge branch 'master' into multioutput
This commit is contained in:
2
.github/actions/setup-tinygrad/action.yml
vendored
2
.github/actions/setup-tinygrad/action.yml
vendored
@@ -302,4 +302,4 @@ runs:
|
||||
- name: Install mesa (macOS)
|
||||
if: inputs.mesa == 'true' && runner.os == 'macOS'
|
||||
shell: bash
|
||||
run: brew install sirhcm/tinymesa/tinymesa
|
||||
run: brew install sirhcm/tinymesa/tinymesa_cpu
|
||||
|
||||
6
.github/workflows/benchmark.yml
vendored
6
.github/workflows/benchmark.yml
vendored
@@ -626,11 +626,11 @@ jobs:
|
||||
- name: benchmark openpilot 0.9.9 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: openpilot compile3 0.10.1 driving_vision
|
||||
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||
- name: openpilot compile3 0.10.1 driving_policy
|
||||
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.10.1 dmonitoring
|
||||
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: benchmark MobileNetV2 on DSP
|
||||
run: |
|
||||
# generate quantized weights
|
||||
|
||||
@@ -520,17 +520,17 @@ generate_mesa() {
|
||||
LVP_NIR_OPTIONS=$(./extra/mesa/lvp_nir_options.sh $MESA_SRC)
|
||||
|
||||
fixup $BASE/mesa.py
|
||||
patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "f'{brew_prefix()}/lib/libtinymesa_cpu.dylib'"
|
||||
patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "brew_path('tinymesa_cpu')" "brew_path('tinymesa')"
|
||||
echo "lvp_nir_options = gzip.decompress(base64.b64decode('$LVP_NIR_OPTIONS'))" >> $BASE/mesa.py
|
||||
cat <<EOF | sed -i "/import ctypes.*/r /dev/stdin" $BASE/mesa.py
|
||||
def brew_prefix():
|
||||
try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip()
|
||||
except Exception: return ''
|
||||
def brew_path(nm):
|
||||
try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib"
|
||||
except Exception: return 'failed'
|
||||
EOF
|
||||
sed -i "/in_dll/s/.*/try: &\nexcept AttributeError: pass/" $BASE/mesa.py
|
||||
sed -i "/in_dll/s/.*/try: &\nexcept (AttributeError, ValueError): pass/" $BASE/mesa.py
|
||||
sed -i "s/import ctypes/import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers/" $BASE/mesa.py
|
||||
sed -i "s/ctypes.CDLL('.\+')/(dll := _try_dlopen_tinymesa_cpu())/" $BASE/mesa.py
|
||||
echo "def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
|
||||
echo "def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
|
||||
sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/mesa.py
|
||||
# bitfield bug in clang2py
|
||||
sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/mesa.py
|
||||
|
||||
@@ -121,6 +121,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file, tol):
|
||||
print("test vs onnx passed")
|
||||
return timings
|
||||
|
||||
def bench(run, inputs):
|
||||
from extra.bench_log import WallTimeEvent, BenchEvent
|
||||
for _ in range(10):
|
||||
with WallTimeEvent(BenchEvent.STEP):
|
||||
run(**inputs).numpy()
|
||||
|
||||
if __name__ == "__main__":
|
||||
onnx_file = fetch(OPENPILOT_MODEL)
|
||||
inputs, outputs = compile(onnx_file)
|
||||
@@ -131,3 +137,5 @@ if __name__ == "__main__":
|
||||
if not getenv("FLOAT16"):
|
||||
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
|
||||
|
||||
if getenv("BENCHMARK_LOG", ""):
|
||||
bench(pickle_loaded, inputs)
|
||||
|
||||
@@ -99,6 +99,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||
args = parser.parse_args()
|
||||
|
||||
N = 1
|
||||
@@ -112,19 +113,22 @@ if __name__ == "__main__":
|
||||
|
||||
model = StableDiffusionV2(**params)
|
||||
|
||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||
weights_fn = args.weights_fn
|
||||
if not weights_fn:
|
||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||
|
||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||
if not args.fakeweights:
|
||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||
weights_fn = args.weights_fn
|
||||
if not weights_fn:
|
||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||
|
||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
if k.startswith("model"):
|
||||
v.replace(v.cast(dtypes.float16).realize())
|
||||
v.replace(v.cast(dtypes.float16))
|
||||
|
||||
Tensor.realize(*get_state_dict(model).values())
|
||||
|
||||
c = { "crossattn": model.cond_stage_model(args.prompt) }
|
||||
uc = { "crossattn": model.cond_stage_model("") }
|
||||
|
||||
@@ -263,14 +263,16 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
||||
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||
args = parser.parse_args()
|
||||
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
||||
if not args.fakeweights:
|
||||
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
|
||||
9
test/external/external_benchmark_schedule.py
vendored
9
test/external/external_benchmark_schedule.py
vendored
@@ -2,7 +2,8 @@ from extra.models.resnet import ResNet50
|
||||
from tinygrad import Tensor, nn, Device
|
||||
from tinygrad.helpers import Profiling, Timing, getenv
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
|
||||
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
|
||||
from tinygrad.codegen.late.control_flow import linearize
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -39,7 +40,7 @@ if __name__ == "__main__":
|
||||
with Timing("***** model linearize in "):
|
||||
uops_line = []
|
||||
for u in rewritten_uops:
|
||||
uops_line.append(apply_rewrites(u, rewrites_for_linearizer))
|
||||
uops_line.append(linearize(u))
|
||||
with Timing("***** model verify in "):
|
||||
for u in uops_line: type_verify(u.arg.lst)
|
||||
print(sum(len(u.arg.lst) for u in uops_line))
|
||||
for u in uops_line: type_verify(u)
|
||||
print(sum(len(u) for u in uops_line))
|
||||
|
||||
4
test/external/speed_v_theoretical.py
vendored
4
test/external/speed_v_theoretical.py
vendored
@@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase):
|
||||
|
||||
# theoretical is nv_tflops=165, amd_tflops=123
|
||||
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=65)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=60)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=115, amd_tflops=60)
|
||||
|
||||
# theoretical is nv_gbs=1008, amd_gbs=960
|
||||
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
|
||||
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750)
|
||||
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=820, amd_gbs=750)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
|
||||
for i,u in enumerate(ranges):
|
||||
if skip and i in skip: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
@@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
uops = get_program(ast, opts=opt).uops
|
||||
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
|
||||
for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
|
||||
for u in uops:
|
||||
if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
|
||||
@@ -214,8 +214,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
else:
|
||||
assert u.src[1].op in GroupOp.ALU
|
||||
assert begin_range < uops.index(u) < end_range
|
||||
# children of STORE are placed after ENDRANGE
|
||||
if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src):
|
||||
# children of END are placed after ENDRANGE
|
||||
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
|
||||
assert end_range < uops.index(u)
|
||||
|
||||
def test_grouped_dims(self):
|
||||
@@ -400,7 +400,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# # check the children's vins
|
||||
# TODO: src ALU are not the same, should it?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
assert len([u for u in uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1
|
||||
assert len([u for u in uops if u.op is Ops.IF and u.src[1] == barrier]) == 1
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
|
||||
@@ -2602,6 +2602,7 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "AMD" and CI, "remu failure?")
|
||||
def test_avg_pool3d_failure(self):
|
||||
with Context(NOOPT=0):
|
||||
helper_test_op([(1,1,16,16,16)],
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import Context, GlobalCounters, CI, CPU_LVP, getenv
|
||||
from tinygrad import Tensor, nn, Device
|
||||
from tinygrad.helpers import Context, GlobalCounters, CI, getenv
|
||||
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
def test_assign_permuted(self):
|
||||
@@ -40,7 +42,7 @@ elif getenv("BIG") > 0:
|
||||
else:
|
||||
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
|
||||
@unittest.skipIf(CPU_LVP, "broken in LVP")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
|
||||
class TestPcontig(unittest.TestCase):
|
||||
def test_flash_attention_bw(self):
|
||||
def fa_bw():
|
||||
@@ -62,7 +64,7 @@ class TestPcontig(unittest.TestCase):
|
||||
Tensor.realize(*ret)
|
||||
return ret
|
||||
|
||||
with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2):
|
||||
with Context(PCONTIG=2, DEBUG=2):
|
||||
grads = fa_bw()
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import numpy as np
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.helpers import SPLIT_REDUCEOP
|
||||
|
||||
class TestTensorUOp(unittest.TestCase):
|
||||
@@ -93,7 +93,6 @@ class TestTensorUOp(unittest.TestCase):
|
||||
out.realize()
|
||||
self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())
|
||||
|
||||
reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, allow_any_len=True, src=(UPat(), UPat((Ops.REDUCE_AXIS, Ops.REDUCE))))))
|
||||
@unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP")
|
||||
class TestReduceOp(unittest.TestCase):
|
||||
def test_no_split_reduce_kernel(self):
|
||||
@@ -101,23 +100,18 @@ class TestReduceOp(unittest.TestCase):
|
||||
a = a.sum()
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 1
|
||||
assert reduce_kernel.match(sched[0].ast, {})
|
||||
|
||||
def test_split_reduce_kernel_dim0(self):
|
||||
a = Tensor.rand(256, 255).realize()
|
||||
a = a.sum()
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
assert reduce_kernel.match(s.ast, {})
|
||||
|
||||
def test_split_reduce_kernel_dim1(self):
|
||||
a = Tensor.rand(255, 256).realize()
|
||||
a = a.sum()
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
assert reduce_kernel.match(s.ast, {})
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -665,19 +665,6 @@ class TestUOpGraph(unittest.TestCase):
|
||||
bad_gate = UOp.const(dtypes.int, 1)
|
||||
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
|
||||
def test_switched_range_order(self):
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
cf = UOp.const(dtypes.float, 0.0)
|
||||
r1 = UOp.range(2, 0)
|
||||
r2 = UOp.range(2, 1)
|
||||
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
|
||||
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
||||
uops = to_uops_list([store])
|
||||
ranges = [x for x in uops if x.op is Ops.RANGE]
|
||||
endranges = [x for x in uops if x.op is Ops.ENDRANGE]
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
@track_rewrites()
|
||||
def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
|
||||
|
||||
@@ -845,8 +832,6 @@ class TestIFUOps(unittest.TestCase):
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
@@ -863,8 +848,6 @@ class TestIFUOps(unittest.TestCase):
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
# this will be fixed with the merge gated stores bounty
|
||||
@unittest.expectedFailure
|
||||
@@ -879,8 +862,6 @@ class TestIFUOps(unittest.TestCase):
|
||||
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
class TestUOpTags(unittest.TestCase):
|
||||
def test_inc_by_one(self):
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
import unittest, random
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import print_uops, UOp, Ops
|
||||
from tinygrad.codegen.late.linearize import block_reorder
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
|
||||
def is_toposorted(lst:list[UOp]):
|
||||
seen = set()
|
||||
for u in lst:
|
||||
if any(p not in seen for p in u.src): return False
|
||||
seen.add(u)
|
||||
return True
|
||||
|
||||
class TestBlockReorder(unittest.TestCase):
|
||||
def _test_randomize(self, golden:list[UOp]):
|
||||
# test random order is always same
|
||||
for _ in range(50):
|
||||
# shuffle and form a valid toposort
|
||||
lst = golden[:]
|
||||
random.shuffle(lst)
|
||||
topolst = []
|
||||
for u in lst:
|
||||
for p in u.toposort():
|
||||
if p not in topolst: topolst.append(p)
|
||||
assert is_toposorted(topolst)
|
||||
|
||||
for x,y in zip(golden, this_order:=block_reorder(topolst)):
|
||||
if x is not y:
|
||||
print_uops(golden)
|
||||
print_uops(this_order)
|
||||
self.assertIs(x, y)
|
||||
|
||||
def _test_render(self, golden:list[UOp]):
|
||||
return OpenCLRenderer().render(golden)
|
||||
|
||||
def test_loads(self):
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0)
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1)
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2)
|
||||
v1 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx0")
|
||||
v2 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx1")
|
||||
v1 = v1*27
|
||||
v2 = v2*4
|
||||
loads = [
|
||||
a.index(v1).load(dtype=dtypes.float),
|
||||
a.index(v1+1).load(dtype=dtypes.float),
|
||||
a.index(v1+2).load(dtype=dtypes.float),
|
||||
a.index(v1+3).load(dtype=dtypes.float),
|
||||
b.index(v2).load(dtype=dtypes.float),
|
||||
b.index(v2+1).load(dtype=dtypes.float),
|
||||
b.index(v2+2).load(dtype=dtypes.float),
|
||||
b.index(v2+3).load(dtype=dtypes.float)]
|
||||
#random.shuffle(loads)
|
||||
sink = c.store(sum(loads)).sink()
|
||||
|
||||
# determine golden order
|
||||
golden = block_reorder(list(sink.toposort()))
|
||||
|
||||
# render for test
|
||||
print(self._test_render(golden))
|
||||
#print_uops(golden)
|
||||
|
||||
# assert the loads are in this order
|
||||
self.assertListEqual([g.src[0].src[1].render() for g in golden if g.op is Ops.LOAD],
|
||||
['(gidx1*4)', '((gidx1*4)+1)', '((gidx1*4)+2)', '((gidx1*4)+3)',
|
||||
'(gidx0*27)', '((gidx0*27)+1)', '((gidx0*27)+2)', '((gidx0*27)+3)'])
|
||||
|
||||
# assert math is after loads
|
||||
first_math = [i for i,g in enumerate(golden) if g.op is Ops.ADD and g.dtype == dtypes.float][0]
|
||||
assert not any(x.op is Ops.LOAD for x in golden[first_math:])
|
||||
|
||||
# confirm the sort is stable
|
||||
self._test_randomize(golden)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -20,8 +20,8 @@ class TestKernelize(unittest.TestCase):
|
||||
self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
|
||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
||||
# input Tensor and user contiguous kernelize
|
||||
self.assertIs(a0.uop.base.op, Ops.ASSIGN)
|
||||
self.assertIs(a.uop.base.op, Ops.ASSIGN)
|
||||
self.assertIs(a0.uop.base.op, Ops.AFTER)
|
||||
self.assertIs(a.uop.base.op, Ops.AFTER)
|
||||
|
||||
def test_two_reduce_w_add(self):
|
||||
a = Tensor.ones(16,16).contiguous()
|
||||
@@ -31,7 +31,7 @@ class TestKernelize(unittest.TestCase):
|
||||
# NOTE: the +1 is fused with a1, so a1 is not kernelized
|
||||
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
|
||||
# the input to the REDUCE_AXIS is an ASSIGN though
|
||||
self.assertIs(a1.uop.base.src[0].base.op, Ops.ASSIGN)
|
||||
self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.uop.symbolic import simplify_valid
|
||||
from tinygrad.helpers import Context
|
||||
from .test_uop_symbolic import check_uop_against_string
|
||||
from test.unit.test_uop_symbolic import check_uop_against_string
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
|
||||
@@ -148,6 +148,14 @@ class TestViz(BaseTestViz):
|
||||
a2 = uop_to_json(a)[id(a)]
|
||||
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
|
||||
|
||||
def test_colored_label_multiline(self):
|
||||
arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
|
||||
src = [Tensor.empty(1).uop for _ in range(10)]
|
||||
a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg)
|
||||
exec_rewrite(a, [PatternMatcher([])])
|
||||
a2 = next(get_viz_details(0, 0))["graph"][id(a)]
|
||||
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
|
||||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
|
||||
b = a.replace(op=Ops.CONST)
|
||||
|
||||
@@ -14,10 +14,10 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
@@ -30,12 +30,6 @@ class RewriteStep:
|
||||
|
||||
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
|
||||
|
||||
rewrites_for_linearizer = [
|
||||
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
|
||||
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
|
||||
RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
|
||||
RewriteStep(pm_finalize, name="Linearizer: Finalize")]
|
||||
|
||||
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
|
||||
# cache with the values of the context vars
|
||||
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
|
||||
@@ -101,11 +95,15 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
|
||||
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
|
||||
|
||||
# return the list (with optional linearizer)
|
||||
return ret + (rewrites_for_linearizer if linearizer else [])
|
||||
# this was the linearizer
|
||||
ret.append(RewriteStep(pm_merge_ends, name="merge ends"))
|
||||
ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True))
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
|
||||
# return the list
|
||||
return ret
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
|
||||
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
|
||||
"""
|
||||
@@ -119,6 +117,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
|
||||
Linear program in UOps.
|
||||
"""
|
||||
|
||||
lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst)
|
||||
lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None))
|
||||
if __debug__: type_verify(lst)
|
||||
return lst
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import math
|
||||
import math, functools, operator
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
|
||||
from tinygrad.helpers import all_int, dedup, get_contraction
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -87,7 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp):
|
||||
except ValueError: continue
|
||||
return s.substitute(subs)
|
||||
|
||||
def add_barrier_and_if(buf:UOp, e:UOp):
|
||||
# TODO: this is not generic
|
||||
local_ranges = [x for x in e.ended_ranges if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE]
|
||||
if len(local_ranges) == 0: return None
|
||||
return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), e.barrier())))
|
||||
|
||||
pm_add_gpudims = PatternMatcher([
|
||||
# add gpudims must be last
|
||||
(UPat(Ops.SINK, name="s"), add_gpudims),
|
||||
# add barrier and if
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.END, name="e"))), add_barrier_and_if),
|
||||
])
|
||||
|
||||
102
tinygrad/codegen/late/control_flow.py
Normal file
102
tinygrad/codegen/late/control_flow.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import heapq
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
|
||||
|
||||
def linearize(u:UOp) -> list[UOp]:
|
||||
lst = list(u.toposort())
|
||||
in_this_block = set(lst)
|
||||
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
priorities:dict[UOp, int] = {}
|
||||
|
||||
# get local children and assign priorities
|
||||
# NOTE: this requires the lst be locally toposorted
|
||||
for u in reversed(lst):
|
||||
in_degree[u] = 0
|
||||
for s in u.src:
|
||||
if s in in_this_block:
|
||||
local_children[s].append(u)
|
||||
in_degree[u] += 1
|
||||
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
|
||||
priority = [0] + [priorities[x] for x in local_children[u]]
|
||||
if u.op is Ops.LOAD: priority.append(-1000)
|
||||
if u.op is Ops.BARRIER: priority.append(-1500)
|
||||
# ranges are scheduled as late as possible so anything that can be outside is
|
||||
#if u.op is Ops.RANGE: priority = [2000]
|
||||
# move defines and consts to the top
|
||||
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
|
||||
priorities[u] = min(priority)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
|
||||
|
||||
# then force then to be toposorted in as close to the ideal order as possible
|
||||
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
|
||||
newlst = []
|
||||
while heap:
|
||||
newlst.append(u:=heapq.heappop(heap)[1])
|
||||
for v in local_children[u]:
|
||||
in_degree[v] -= 1
|
||||
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
|
||||
|
||||
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
|
||||
return newlst
|
||||
|
||||
class CFGContext:
|
||||
def __init__(self, sink:UOp):
|
||||
# there are 3 relationships between ranges:
|
||||
# nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
|
||||
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
|
||||
# independent, endrange y is not a dependency of endrange x
|
||||
# everything is nested inside the sink
|
||||
deps: dict[UOp, set[UOp]] = {}
|
||||
nesting: dict[UOp, UOp] = {}
|
||||
for u in sink.toposort():
|
||||
deps[u] = set().union(*(deps[s] for s in u.src))
|
||||
if u.op in (Ops.END, Ops.ENDIF, Ops.SINK):
|
||||
nesting |= {x:u for x in deps[u] if x.op in (Ops.END, Ops.ENDIF) and (u.op is Ops.SINK or u.src[0] in deps[x]) and x not in nesting}
|
||||
if u.op in (Ops.RANGE, Ops.END, Ops.IF, Ops.ENDIF): deps[u] |= {u}
|
||||
|
||||
self.edges: dict[UOp, UOp] = {}
|
||||
siblings: dict[UOp, list[UOp]] = {}
|
||||
for k,vv in nesting.items(): siblings.setdefault(vv, []).append(k)
|
||||
for k,v in siblings.items():
|
||||
# range/if that have dependencies on other siblings need to run after them
|
||||
order = sorted(v, key=lambda x: len(deps[x].intersection(v)))
|
||||
zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[0]] + order, order)
|
||||
for x,y in zipped:
|
||||
# TODO: is this check correct?
|
||||
if y.src[0] not in x.backward_slice_with_self:
|
||||
self.edges[y.src[0]] = x
|
||||
|
||||
pm_add_control_flow = PatternMatcher([
|
||||
(UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
])
|
||||
|
||||
def do_merge_ends(s:UOp):
|
||||
# NOTE: this can fail
|
||||
stacked: dict[UOp, list[UOp]] = {}
|
||||
dangling_ifs = []
|
||||
for x in s.toposort():
|
||||
if x.op in {Ops.END, Ops.ENDIF}:
|
||||
assert x.op is not Ops.END or x.arg == 1, "ends must be single ends for linearizer"
|
||||
stacked.setdefault(x.src[0], []).append(x)
|
||||
if x.op is Ops.IF: dangling_ifs.append(x)
|
||||
dangling_ifs = [x for x in dangling_ifs if x not in stacked]
|
||||
replaces = {}
|
||||
for k,v in stacked.items():
|
||||
if len(v) == 1: continue
|
||||
rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=x[0].arg)
|
||||
for x in v: replaces[x] = rep
|
||||
if not len(replaces) and not len(dangling_ifs): return None
|
||||
ret = s.substitute(replaces)
|
||||
if len(dangling_ifs):
|
||||
assert len(dangling_ifs) == 1, "we only support 1 dangling if"
|
||||
ret = ret.replace(src=(UOp(Ops.ENDIF, src=(dangling_ifs[0], *ret.src)),))
|
||||
return ret
|
||||
|
||||
pm_merge_ends = PatternMatcher([
|
||||
# for renderering and linearizing, all ends must end one loop
|
||||
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
|
||||
(UPat(Ops.SINK, name="s"), do_merge_ends),
|
||||
])
|
||||
@@ -123,7 +123,7 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
|
||||
|
||||
load_store_folding = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines).or_after(name="buf")), UPat.var("vec"))), expand_index),
|
||||
# GEP after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||
@@ -242,11 +242,13 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
# CAST after AFTER
|
||||
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
|
||||
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
])
|
||||
|
||||
pm_render = PatternMatcher([
|
||||
@@ -266,9 +268,9 @@ pm_render = PatternMatcher([
|
||||
UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
# gate any stores that aren't gated with ifs
|
||||
# gate any stores that aren't gated with if/endif pairs
|
||||
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
|
||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
|
||||
lambda store,idx: UOp(Ops.ENDIF, src=(uif:=UOp(Ops.IF, src=(idx.src[2],)), UOp(Ops.STORE, src=store.src[:2]+(uif,)+store.src[2:]))) if \
|
||||
len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
|
||||
])
|
||||
|
||||
@@ -293,15 +295,17 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
# if we have a range
|
||||
if len(reduce_range) != 0:
|
||||
topo = inp.toposort()
|
||||
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
|
||||
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
|
||||
ended_ranges = flatten([x.src[:x.arg] for x in topo if x.op is Ops.END])
|
||||
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
|
||||
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
|
||||
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
|
||||
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
|
||||
acc.index(UOp.const(dtypes.int, 0)).store(identity)
|
||||
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
|
||||
if len(reduce_range) == 0: return ret
|
||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(ends=reduce_range[::-1])).index(UOp.const(dtypes.int, 0)).load()
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
||||
|
||||
@@ -87,7 +87,7 @@ expander = PatternMatcher([
|
||||
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
Ops.VECTORIZE, Ops.IF, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# BARRIERs aren't actually expanded
|
||||
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
|
||||
@@ -145,8 +145,7 @@ def fix_group_for_reduce(x:UOp):
|
||||
reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
|
||||
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
|
||||
|
||||
# gate with an if on the store + do the final reduce
|
||||
buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
|
||||
# do the final reduce (if/barrier are added in gpudims step)
|
||||
return buf.reduce(*reduce_loop, arg=x.arg)
|
||||
|
||||
pm_pre_expander = PatternMatcher([
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import heapq
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
|
||||
from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER
|
||||
|
||||
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
||||
def block_reorder(lst:list[UOp]) -> list[UOp]:
|
||||
in_this_block = set(lst)
|
||||
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
priorities:dict[UOp, int] = {}
|
||||
|
||||
# get local children and assign priorities
|
||||
# NOTE: this requires the lst be locally toposorted
|
||||
for u in reversed(lst):
|
||||
in_degree[u] = 0
|
||||
for s in u.src:
|
||||
if s in in_this_block:
|
||||
local_children[s].append(u)
|
||||
in_degree[u] += 1
|
||||
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
|
||||
priority = [0] + [priorities[x] for x in local_children[u]]
|
||||
if u.op is Ops.LOAD: priority.append(-1000)
|
||||
if u.op is Ops.BARRIER: priority.append(-1500)
|
||||
priorities[u] = min(priority)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
|
||||
|
||||
# then force then to be toposorted in as close to the ideal order as possible
|
||||
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
|
||||
newlst = []
|
||||
while heap:
|
||||
newlst.append(u:=heapq.heappop(heap)[1])
|
||||
for v in local_children[u]:
|
||||
in_degree[v] -= 1
|
||||
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
|
||||
|
||||
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
|
||||
return newlst
|
||||
|
||||
# ***** basic block *****
|
||||
|
||||
def disp(y:UOp) -> str:
|
||||
if y.op is Ops.IF: return f'IF{id(y)}'
|
||||
if y.op is Ops.RANGE: return str(y.arg)
|
||||
return "<NONE>"
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class BasicBlock:
|
||||
lst: tuple[UOp, ...]
|
||||
ctx: tuple[UOp, ...] = ()
|
||||
end: UOp|None = None
|
||||
cnt: int = 0
|
||||
child_ctx: tuple[UOp, ...]|None = None
|
||||
def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
|
||||
def __repr__(self):
|
||||
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
|
||||
f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
|
||||
f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
|
||||
def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
|
||||
|
||||
def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))
|
||||
|
||||
# ***** block context *****
|
||||
|
||||
@dataclass
|
||||
class BlockContext:
|
||||
child_count: dict[UOp, int]
|
||||
block_ctxs: dict[UOp, tuple[UOp, ...]]
|
||||
child_ctxs: dict[UOp, tuple[UOp, ...]]
|
||||
def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
|
||||
@staticmethod
|
||||
def from_sink(sink:UOp) -> BlockContext:
|
||||
# get children and all block contexts
|
||||
ctx = BlockContext({}, {}, {})
|
||||
for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
|
||||
this_block_ctx: list[UOp] = []
|
||||
ctx.child_count[u] = 0
|
||||
|
||||
# get children and accumulate the last_ctx
|
||||
for s in u.src:
|
||||
if s.op is Ops.SPECIAL: continue
|
||||
# NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
|
||||
ctx.child_count[s] += 1
|
||||
this_block_ctx += ctx.last_ctx(s)
|
||||
|
||||
# save the block ctx. SINK never has anything
|
||||
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
|
||||
|
||||
# RANGE/IF add to the next ctx
|
||||
# STORE/ASSIGN subtract from the next ctx
|
||||
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
|
||||
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
|
||||
return ctx
|
||||
|
||||
# ***** make blocks *****
|
||||
|
||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
|
||||
|
||||
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
|
||||
ends_to_add = [z for z in new_ctx if z not in current_ctx]
|
||||
while len(ends_to_add):
|
||||
r:UOp = ends_to_add.pop(-1)
|
||||
new_ctx = tuple([z for z in new_ctx if z is not r])
|
||||
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
|
||||
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
|
||||
return base_block
|
||||
|
||||
def make_block_bottom_up(ctx:BlockContext, x:UOp):
|
||||
if x.op is Ops.BLOCKSTART:
|
||||
current_ctx, child_ctx = x.arg
|
||||
lst = list(x.src)
|
||||
child_count = 1
|
||||
else:
|
||||
current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
|
||||
lst = [x]
|
||||
|
||||
# count of times we've seen this block, or a seed for a new block if we can't merge it
|
||||
unmergable: defaultdict[UOp, int] = defaultdict(int)
|
||||
blockseeds = defaultdict(list)
|
||||
|
||||
# add the srcs of this to the frontier
|
||||
# NOTE: things may be in here multiple times, that's okay
|
||||
frontier_nodes = list(flatten(y.src[::-1] for y in lst))
|
||||
while len(frontier_nodes):
|
||||
u = frontier_nodes.pop(0)
|
||||
if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
|
||||
# count is correct
|
||||
if (newctx:=ctx.block_ctxs[u]) == current_ctx:
|
||||
# block has same context, merge it, and put the srcs on the frontier
|
||||
lst.append(u)
|
||||
frontier_nodes.extend(u.src[::-1])
|
||||
else:
|
||||
# block has different context, add it to blockseeds
|
||||
blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
|
||||
del unmergable[u]
|
||||
else:
|
||||
# count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
|
||||
unmergable[u] += 1
|
||||
|
||||
# add unmergables to sources
|
||||
srcs = []
|
||||
for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt
|
||||
|
||||
# add blockseeds, with blockends as needed
|
||||
for (new_ctx, new_child_ctx), v in blockseeds.items():
|
||||
base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
|
||||
srcs.append(add_blockends(base_block, new_ctx, current_ctx))
|
||||
|
||||
lst = lst[::-1]
|
||||
if BLOCK_REORDER: lst = block_reorder(lst)
|
||||
bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
|
||||
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
||||
|
||||
# we prevent the source of the SPECIAL from being linearized since its not part of the kernel
|
||||
def raise_bottom_up_gate(): raise BottomUpGate()
|
||||
|
||||
block_create = PatternMatcher([
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
|
||||
(UPat(Ops.SPECIAL), raise_bottom_up_gate)
|
||||
])
|
||||
|
||||
# ***** blockend merging ****
|
||||
|
||||
def merge_blockends(sink:UOp) -> UOp|None:
|
||||
# only run on the final BLOCK with the SINK in it
|
||||
if sink.arg.lst[-1].op is not Ops.SINK: return None
|
||||
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
|
||||
blockends_to_arg: dict[UOp, list[UOp]] = {}
|
||||
for be in sink.toposort():
|
||||
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
|
||||
new_forks = {}
|
||||
for k,v in blockends_to_arg.items():
|
||||
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
||||
if len(v) > 1:
|
||||
bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
|
||||
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
|
||||
# NOTE: bb.ctx != u.arg.ctx can cause problems here
|
||||
for u in v: new_forks[u] = out
|
||||
if len(new_forks) == 0: return None
|
||||
return sink.substitute(new_forks)
|
||||
|
||||
pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
|
||||
|
||||
# ***** block merging ****
|
||||
|
||||
def merge_block(x:UOp):
|
||||
unmergable_blocks, mergable_blocks = [], []
|
||||
mergable_dict: defaultdict[UOp, int] = defaultdict(int)
|
||||
for y in x.src:
|
||||
if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
|
||||
elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
|
||||
else: unmergable_blocks.append(y)
|
||||
for k,v in mergable_dict.items():
|
||||
if v == k.arg.cnt: mergable_blocks.append(k)
|
||||
else: unmergable_blocks.extend([k]*v)
|
||||
if len(mergable_blocks) == 0: return None
|
||||
del mergable_dict
|
||||
|
||||
# create the block
|
||||
arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
|
||||
return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
|
||||
|
||||
def remove_blockend(x:UOp):
|
||||
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
|
||||
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
|
||||
|
||||
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
|
||||
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
|
||||
parent_block = parent_blocks[0]
|
||||
assert len(parent_blocks) == parent_block.arg.cnt
|
||||
# NOTE: DEFINE_ACC doesn't have to be handled in any special way
|
||||
late_ops = list(x.arg.lst)
|
||||
# NOTE: we have to add a barrier at the start if barrier is used in the range
|
||||
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
|
||||
late_ops = [UOp(Ops.BARRIER)] + late_ops
|
||||
# peephole opt, remove any BARRIERs next to each other
|
||||
for i in range(len(late_ops)-1):
|
||||
if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
|
||||
arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
|
||||
return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
|
||||
# else the whole context ended by the blockend is already in this block and we can safely turn it into a block
|
||||
return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))
|
||||
|
||||
block_merge = PatternMatcher([
|
||||
(UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
|
||||
(UPat(Ops.BLOCKEND, name="x"), remove_blockend),
|
||||
])
|
||||
|
||||
# ****** finalize ******
|
||||
|
||||
def finalize(sink:UOp) -> UOp:
|
||||
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
|
||||
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
|
||||
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
|
||||
|
||||
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|
||||
@@ -4,8 +4,8 @@ from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
|
||||
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -64,21 +64,8 @@ class Scheduler:
|
||||
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
|
||||
|
||||
def _globalizable_rngs(self) -> list[UOp]:
|
||||
store_rngs = self.ast.src[0].src[2:]
|
||||
|
||||
# filter any not in local stores
|
||||
local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
|
||||
or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
|
||||
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
|
||||
# filter any not in reduces
|
||||
# TODO: enable this
|
||||
"""
|
||||
reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE]
|
||||
for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
"""
|
||||
|
||||
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []
|
||||
# all ranges that end before any STOREs
|
||||
return [x for x in self.ast.toposort(lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in self.ast.ranges]
|
||||
|
||||
def convert_loop_to_global(self):
|
||||
if not self.opts.has_local: return None
|
||||
@@ -89,11 +76,11 @@ class Scheduler:
|
||||
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
|
||||
|
||||
def colors(self) -> list[str]:
|
||||
store_rngs = flatten([x.src[2:] for x in self.ast.src])
|
||||
globalizible_rngs = self._globalizable_rngs()
|
||||
ret = []
|
||||
for x,r in zip(self.axis_types, self.rngs):
|
||||
if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
|
||||
elif r not in store_rngs and x == AxisType.LOOP: ret.append("BLACK")
|
||||
elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("BLACK")
|
||||
else: ret.append(axis_colors[x])
|
||||
return ret
|
||||
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
|
||||
@@ -348,7 +335,7 @@ def apply_opts(ctx:Renderer, ast:UOp):
|
||||
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
|
||||
if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD):
|
||||
if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice):
|
||||
k = hand_coded_optimizations(k)
|
||||
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
|
||||
|
||||
|
||||
@@ -156,7 +156,9 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
|
||||
if lib in seen_libs: continue
|
||||
# filter out kernels that use 1000x more compute than the smallest
|
||||
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
|
||||
if least_compute_ops*1000 < this_compute_ops: continue
|
||||
if least_compute_ops*1000 < this_compute_ops:
|
||||
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}")
|
||||
continue
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
|
||||
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
|
||||
@@ -13,14 +13,16 @@ def flatten_range(r:UOp):
|
||||
pm_flatten_range = PatternMatcher([
|
||||
# real ranges only
|
||||
(UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range),
|
||||
# END is only on RANGES. TODO: this is copied from symbolic
|
||||
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
|
||||
])
|
||||
|
||||
def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
def simplify_merge_adjacent(u:UOp) -> UOp|None:
reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE]
i = range_start[u.op]
while i < len(u.src)-1:
r0, r1 = u.src[i], u.src[i+1]
i = 0
while i < len(u.ended_ranges)-1:
r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1]
# check same type
if r0.arg[-1] == r1.arg[-1]:
# check if the ranges to merge are in the same reduces

@@ -39,7 +41,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
return u

pm_simplify_ranges = PatternMatcher([
(UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
(UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
])

def mark_range_mod(ctx, r:UOp, c:UOp):

@@ -57,7 +59,7 @@ def do_substitute(ctx, x: UOp):

def dont_sub_ranges_for_image(ctx, x:UOp):
if isinstance(x.src[0].dtype, ImageDType):
for s in x.src[1:]: ctx[s] = None
for s in x.src[0].ranges: ctx[s] = None

pm_split_ranges = PatternMatcher([
(UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod),
@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import unwrap_class_type
from tinygrad.helpers import unwrap_class_type, suppress_finalizing
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer

@@ -163,6 +163,7 @@ class Buffer:
return self._trace_num
@property
def nbytes(self): return self.size*self.dtype.itemsize
@suppress_finalizing
def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
def __repr__(self):
return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \

@@ -22,18 +22,18 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
in_degree: dict[UOp, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src:
if s.op is Ops.ASSIGN:
if s.op is Ops.AFTER:
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:

@@ -43,7 +43,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
else:
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")

# linearize KERNEL UOps into ScheduleItems in BFS order
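For readers following the ASSIGN-to-AFTER rename: the dependency walk above only looks at AFTER nodes, where src[0] is the written buffer and src[1] is the KERNEL that realizes it. A rough sketch of just that edge-collection step, assuming the Ops/UOp API shown in this hunk (kernel_children is a hypothetical helper, not part of the diff):

from collections import defaultdict

def kernel_children(sched_sink):
  children, in_degree = defaultdict(list), {}
  for u in sched_sink.toposort():
    if u.op is not Ops.AFTER: continue  # only AFTER nodes carry a realized kernel
    k = u.src[1]                        # the KERNEL uop that writes this buffer
    in_degree.setdefault(k, 0)
    for s in k.src:
      if s.op is Ops.AFTER:             # an input produced by another kernel
        children[s.src[1]].append(k)
        in_degree[k] += 1
  return children, in_degree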
@@ -170,9 +170,9 @@ SPEC = ContextVar("SPEC", 0)
# TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)

@dataclass(frozen=True)
class Metadata:
name: str

@@ -36,7 +36,7 @@ class BatchNorm:
self.weight: Tensor|None = Tensor.ones(sz) if affine else None
self.bias: Tensor|None = Tensor.zeros(sz) if affine else None

self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
self.num_batches_tracked = Tensor.zeros(dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
@@ -28,9 +28,12 @@ class Estimates:
mult_stack: list[sint] = []
dont_count: set[UOp] = set()
if ignore_indexing:
def range_gate(x): return x.op is not Ops.RANGE
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
dont_count = dont_count.union(u.src[0].toposort())
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())

@@ -45,7 +48,7 @@ class Estimates:
mults *= cast(sint, u.src[0].ssimplify())
# SPECIAL are already counted in mults
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults
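In the Estimates walk above, mults carries the product of enclosing trip counts; a RANGE multiplies its count in and the renamed END pops back to the enclosing value. A small sketch of that pairing (the RANGE branch is assumed from the surrounding code; only the END branch is visible in this hunk):

if u.op is Ops.RANGE:
  mult_stack.append(mults)                   # remember the enclosing multiplier
  mults *= cast(sint, u.src[0].ssimplify())  # multiply in this loop's trip count
elif u.op is Ops.END:
  mults = mult_stack.pop(-1)                 # leaving the loop restores the old multiplier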
@@ -1,7 +1,7 @@
from typing import Literal, Callable, cast
import os, math, sys
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.codegen.opt import tc, axis_letters
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate

@@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(Ops.RANGE, name="x"),

@@ -144,6 +144,9 @@ class CStyleLanguage(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

@@ -160,7 +163,7 @@ class CStyleLanguage(Renderer):
# naming
prefix = None
if u.op is Ops.SPECIAL: r[u] = u.arg
elif u.op is Ops.RANGE: r[u] = "ridx"+range_str(u)
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",

@@ -170,7 +173,7 @@ class CStyleLanguage(Renderer):
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
@@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([
f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
(UPat(Ops.END, name="x"), lambda ctx,x:
f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"

@@ -167,6 +167,9 @@ class LLVMRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

@@ -1,4 +1,4 @@
from typing import Callable, cast
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap
from tinygrad.renderer import Renderer

@@ -169,10 +169,13 @@ class NIRRenderer(Renderer):
def render(self, uops:list[UOp]):
self.prerender(uops)
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
self.r, self.param_idx, ranges = {}, 0, []
self.r: dict[UOp, Any] = {}
self.param_idx, ranges = 0, []

for u in uops:
if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:
if u.arg is not None: self.b.shader.contents.info.name = mesa.char_pointer_cast(u.arg.function_name)
elif u.op == Ops.DEFINE_LOCAL:

@@ -183,7 +186,7 @@ class NIRRenderer(Renderer):
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
elif u.op == Ops.ENDRANGE:
elif u.op == Ops.END:
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
mesa.nir_pop_loop(self.b, None)
@@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([
if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
(UPat(Ops.END, name="x", src=(UPat.var("src0"),), allow_any_len=True), lambda ctx, x, src0: [
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),

@@ -179,6 +179,9 @@ class PTXRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

@@ -216,7 +219,7 @@ class PTXRenderer(Renderer):
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
if prefix: r[u] = ssa(prefix, u, dtype)
@@ -7,13 +7,14 @@
# LONGDOUBLE_SIZE is: 16
#
import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers
def brew_prefix():
try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip()
except Exception: return ''
def brew_path(nm):
try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib"
except Exception: return 'failed'
PATHS_TO_TRY = [
(BASE:=os.getenv('MESA_PATH', f"/usr{'/local/' if helpers.OSX else '/'}lib"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so'),
f'{BASE}/libtinymesa{EXT}',
f'{brew_prefix()}/lib/libtinymesa_cpu.dylib',
brew_path('tinymesa_cpu'),
brew_path('tinymesa'),
]
def _try_dlopen_tinymesa_cpu():
library = ctypes.util.find_library("tinymesa_cpu")

@@ -6087,7 +6088,7 @@ struct_nir_op_info._fields_ = [

nir_op_info = struct_nir_op_info
try: nir_op_infos = (struct_nir_op_info * 489).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_op_infos')
except AttributeError: pass
except (AttributeError, ValueError): pass
try:
nir_op_is_selection = _libraries['FIXME_STUB'].nir_op_is_selection
nir_op_is_selection.restype = ctypes.c_bool

@@ -8118,7 +8119,7 @@ c__EA_nir_intrinsic_index_flag = ctypes.c_uint32 # enum
nir_intrinsic_index_flag = c__EA_nir_intrinsic_index_flag
nir_intrinsic_index_flag__enumvalues = c__EA_nir_intrinsic_index_flag__enumvalues
try: nir_intrinsic_index_names = (ctypes.POINTER(ctypes.c_char) * 75).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_index_names')
except AttributeError: pass
except (AttributeError, ValueError): pass
class struct_nir_intrinsic_instr(Structure):
pass

@@ -8242,7 +8243,7 @@ struct_nir_intrinsic_info._fields_ = [

nir_intrinsic_info = struct_nir_intrinsic_info
try: nir_intrinsic_infos = (struct_nir_intrinsic_info * 732).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_infos')
except AttributeError: pass
except (AttributeError, ValueError): pass
try:
nir_intrinsic_src_components = _libraries['libtinymesa_cpu.so'].nir_intrinsic_src_components
nir_intrinsic_src_components.restype = ctypes.c_uint32

@@ -19877,4 +19878,4 @@ __all__ = \
'union_util_format_description_0', 'util_format_colorspace',
'util_format_layout', 'va_list']
lvp_nir_options = gzip.decompress(base64.b64decode('H4sIAAAAAAAAA2NgZGRkYGAAkYxgCsQFsxigwgwQBoxmhCqFq2WEKwIrAEGIkQxoAEMALwCqVsCiGUwLMHA0QPn29nBJkswHANb8YpH4AAAA'))
def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)')
def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)')
@@ -52,11 +52,11 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is Ops.ENDRANGE:
if uop is Ops.END:
loop_ends[idp[0]] = i
i = idp[0]
continue

@@ -72,7 +72,8 @@ class PythonProgram:
if g: _store(m, o+j, v, dtp[1].scalar())
i += 1
continue
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
if uop is Ops.AFTER: ul[i] = inp[0]
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")

@@ -58,7 +58,7 @@ class LLVMCompiler(Compiler):
self.diag_msgs.append(msg)
self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")

def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)

@@ -69,7 +69,9 @@ class PTXCompiler(Compiler):
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)

class NVPTXCompiler(PTXCompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
def __init__(self, arch:str):
nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint())))
super().__init__(arch, cache_key="nv_ptx")
def compile(self, src:str) -> bytes:
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
@@ -52,11 +52,11 @@ class IndexingContext:

def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.KERNEL}: return None
if x.op is Ops.ASSIGN and x.src[1].op is Ops.KERNEL: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
new_srcs = []
for s in x.src:
new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]

@@ -2,10 +2,9 @@ from typing import cast
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate
from tinygrad.uop.symbolic import symbolic_flat
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
from tinygrad.helpers import DEBUG_RANGEIFY
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op

@@ -137,6 +136,9 @@ def cleanup_dead_axes(b:UOp):
# move the tag to the expand. NOTE: this expand tag might not survive
return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag)

def gate_substitute(ctx, b:UOp) -> None:
if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()
pm_gate_substitute = PatternMatcher([(UPat(GroupOp.All, name="b"), gate_substitute)], compiled=False)
# if a buffer is being stored just for permutes or something, remove it
# we want to reexpress the indexes of idx2 in terms of the implied b1
def remove_bufferize(src:UOp, buf:UOp, idx:UOp):

@@ -179,11 +181,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# if it makes it here, the bufferize is removed
# this is the ranges replaced
# NOTE: if buf src is a const, we don't replace it
if REAL_SUBSTITUTE:
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST})
else:
replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST])
return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2]))))
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute)

def pre_bufferize(b:UOp, x:UOp, copy:UOp):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])

@@ -243,7 +241,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
bufs: set[UOp] = set()
def gate_input(u:UOp):
# TODO: add cache to fix n^2
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
@@ -278,7 +276,8 @@ def bufferize_to_store(x:UOp):
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag)
do_store = assign_target.replace(dtype=sdtype).store(assign_src).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE])
ret = assign_target.src[0].after(do_store)
mops = []
walk = assign_mops
while walk is not assign_mops.base:

@@ -290,8 +289,8 @@ def bufferize_to_store(x:UOp):
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag)
ret = ret.forced_reshape(shape)
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE])
ret = buf.after(do_store).forced_reshape(shape)
# TODO: is this right? what if it's offset
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])

@@ -302,9 +301,8 @@ def bufferize_to_store(x:UOp):
tag = x.arg.device
if tag is None: tag = UOp.unique().arg # TODO: hack
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
# store has the other dtype here
# TODO: how is this unified?
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape)
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).end(ends=[x for x in rngs if x.op is Ops.RANGE])
return buf.after(do_store).reshape(shape)

pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
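The pattern in bufferize_to_store is the same in all three branches: the STORE is closed with an END over the RANGEs it writes, and the buffer is wrapped in AFTER so anything reading the result is ordered behind the write. A minimal sketch of that shape, assuming the UOp helpers shown in this diff (buf, val, rngs, sdtype are placeholders):

do_store = buf.index(*rngs, dtype=sdtype).store(val).end(ends=[r for r in rngs if r.op is Ops.RANGE])
ret = buf.after(do_store)  # consumers of ret run after the store in the toposort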
@@ -336,12 +334,13 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
ctx.vars[b] = None
return b.src[0]

def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
buf = assign.as_buf()
def handle_after(ctx:LocalAddBufferContext, after:UOp):
if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None
buf = after.as_buf()
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
assert buf not in ctx.map
ctx.map[buf] = assign
ctx.map[buf] = after
return buf

def renumber_range(ctx:LocalAddBufferContext, r:UOp):

@@ -351,7 +350,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
return ret

def find_bufs(x:UOp):
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.ASSIGN) if s.op is Ops.INDEX]
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX]
read_from: dict[UOp, Ops] = {}
if any((buf:=idx.as_buf()).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs):
raise RuntimeError(f"cycle detected while indexing {buf}")

@@ -360,7 +359,7 @@ to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="x"), find_bufs),
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),

# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic

@@ -389,16 +388,9 @@ rangeify_codegen = PatternMatcher([

# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),

# TODO: this can be moved into codegen
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),

# TODO: hack for group for reduce
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
.f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
])

def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
@@ -419,9 +411,8 @@ class Kernel:
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"

def split_store(ctx:list[UOp], x:UOp):
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len(x.ranges): return None
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None

# local kernel rewrite
lctx = LocalAddBufferContext()

@@ -431,16 +422,22 @@ def split_store(ctx:list[UOp], x:UOp):
metadatas = [ctx[y].metadata for y in lctx.parent_tags]

# NOTE: the hack for COPY is here
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \
if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
for u in ret.toposort():
# TODO: this can be wrong if there's multiple of these
if u.op in {Ops.COPY, Ops.BUFFER_VIEW}:
ret = u
break
else:
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)

kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
return x.as_buf().assign(kernel)
return kernel

split_kernels = PatternMatcher([
(UPat(Ops.STORE, name="x"), split_store),
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])

def tag_uop(ctx:list[UOp], x:UOp):

@@ -471,25 +468,6 @@ replace_contiguous = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])

def do_sub_recurse(s:UOp):
x,keys,values = s.src[0], s.src[1].src, s.src[2].src
# SUBSTITUTE applied to SUBSTITUTE runs the child SUB on the parents. though this is probably wrong in the generic case
if x.op is Ops.SUBSTITUTE:
sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:])
sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:])
return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v))
# here we actually do the SUBSTITUTE
if x in keys: return values[keys.index(x)]
# we filter any keys where the ranges don't overlap. this keeps the algorithm O(output graph size)
x_ranges = x.ranges
new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)}
# if there's no SUBSTITUTEs left, we can just return x
if len(new_kv) == 0: return x
# then we add SUBSTITUTE to all parents
uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values()))
return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src]))
pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)])

@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")

@@ -504,8 +482,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# TODO: can you substitute and remove costly buffers at the same time?
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")

# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph

@@ -524,13 +500,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tsink.toposort():
if u.op is not Ops.ASSIGN: continue
if u.op is not Ops.AFTER: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
@@ -249,9 +249,9 @@ class Tensor(MathTrait):
self.kernelize(*lst)
sink = UOp.sink(*[x.uop for x in (self,)+lst])

# remove all ASSIGNs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
_apply_map_to_tensors(remove_assign_map, name="Remove Assigns")
# remove all AFTERs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.AFTER}
_apply_map_to_tensors(remove_assign_map, name="Remove After")

# create the schedule
schedule, var_vals = create_schedule_with_vars(sink)

@@ -12,19 +12,18 @@ class Ops(FastEnum):
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
SENTINEL = auto()

# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
AFTER = auto()

# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702

# create buffer
BUFFERIZE = auto()
SUBSTITUTE = auto()

# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702

# blocks in linearizer (only used there)
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702

# movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op

@@ -67,7 +66,7 @@ class Ops(FastEnum):
WHERE = auto(); MULACC = auto() # noqa: E702

# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702

# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702

@@ -91,7 +90,6 @@ class GroupOp:
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}

# BinaryOps that can be flipped
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}
@@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

@property
def ptrdtype(self) -> PtrDType:
if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType")
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
return self.dtype

# *** uop shape stuff ***

@@ -179,7 +179,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
return None

# some ops init the shape

@@ -190,7 +190,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)

# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE: return self.src[0]._shape
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER | Ops.END:
return self.src[0]._shape

# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape

@@ -275,6 +276,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
for s in UOp.sink(*self.src[range_start[self.op]:]).ranges:
if s in ret: del ret[s]
elif self.op is Ops.END:
for s in self.src[self.arg:]: ret.update(s.ranges)
for s in UOp.sink(*self.src[:self.arg]).ranges:
if s in ret: del ret[s]
else:
for s in self.src: ret.update(s.ranges)
return ret

@@ -284,6 +289,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.RANGE: return {self:None}
return self._ranges

@functools.cached_property
def ended_ranges(self):
match self.op:
case Ops.REDUCE: return self.src[1:]
case Ops.END: return self.src[:self.arg]
case _: raise RuntimeError(f"{self.op} doesn't end ranges")

# *** uop evaluation ***

def simplify(self, tracked=False, full_symbolic=True):
@@ -324,7 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(idx)
def __getitem__(self, *idx): return self.index(*idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)

@@ -349,6 +361,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
def end(self, *src:UOp, ends:Sequence[UOp]):
if len(ends) == 0: return self
return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends))
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def alu(self, op, *src:UOp, **kwargs):

@@ -525,6 +541,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT:
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
return self.src[0].device[self.arg]

@@ -538,8 +555,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return self
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
assert self.op is Ops.AFTER, f"must be AFTER {self.op}"
return self.src[0].buf_uop.base

def as_buf(self) -> UOp:
if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
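The END op added here keeps the ranges it closes at the front of its src: src = (*ends, payload, *extra) with arg = len(ends), which is exactly what ended_ranges and the ranges property slice back out. A tiny sketch under that assumption (r0, r1, payload are placeholders):

e = payload.end(ends=[r0, r1])          # UOp(Ops.END, src=(r0, r1, payload), arg=2)
assert e.ended_ranges == e.src[:e.arg]  # (r0, r1), which no longer appear in e.ranges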
@@ -810,6 +827,8 @@ class UPat(MathTrait):
@staticmethod
def any(*src): return UPatAny(src=src)
def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
def or_after(self, name:str|None=None):
return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))

@staticmethod
@functools.cache

@@ -836,7 +855,6 @@ class UPat(MathTrait):
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def fuse(self): return self.alu(Ops.FUSE)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def or_broadcasted(self, **kwargs): return UPat.any(self, self.broadcast(**kwargs))
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)

def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))

@@ -1091,7 +1109,7 @@ class RewriteContext:
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
except BottomUpGate:
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
self.replace[n] = new_n
self.replace[n] = unwrap(test_n)
continue
stack.append((n, 1, new_n))
for x in reversed(new_n.src):

@@ -1171,6 +1189,7 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
# TODO: this is only triggering if they are all casts, correct?
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
@@ -66,8 +66,8 @@ buffer_spec = PatternMatcher([
])

assign_spec = PatternMatcher([
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),

# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),

@@ -111,6 +111,9 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([

# REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),

# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True)
])

# ***** uop type spec *****

@@ -156,12 +159,16 @@ spec = PatternMatcher([
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),

(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),

# allow AFTER on buffers
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),

# **** new style load/store ****

# make sure all index dtypes have been lowered

@@ -171,18 +178,14 @@ spec = PatternMatcher([

# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),

# LOAD on STORE
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),

# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),

# STORE takes a <bufidx, val, gate?>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
# STORE takes a <bufidx, val, ranges...>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),

# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
@@ -193,7 +196,7 @@ spec = PatternMatcher([
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),

(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
(UPat(Ops.END, dtype=dtypes.void), lambda: True),

# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),

@@ -201,9 +204,8 @@ spec = PatternMatcher([
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),

# if has a <gate, barrier?>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True),

(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),

@@ -234,9 +236,6 @@ full_spec = PatternMatcher([
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),

# allow any SUBSTITUTE
(UPat(Ops.SUBSTITUTE), lambda: True),

# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine

@@ -269,7 +268,7 @@ full_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),

# linearizer: outputs + intermediate KERNELs
(UPat((Ops.BLOCKSTART, Ops.BLOCK, Ops.BLOCKFINAL, Ops.BLOCKEND, Ops.KERNEL), dtype=dtypes.void), lambda: True),
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),

# allow index dtype on a restricted set of UOps
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,

@@ -283,6 +282,8 @@ full_spec = PatternMatcher([
(UPat(Ops.DEFINE_VAR), lambda: True),
# reshape on STORE
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_uop_spec+spec

# ***** uop helpers *****

@@ -377,6 +377,13 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# END is only on RANGES
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
])+gep_pushing

symbolic_flat = symbolic+PatternMatcher([
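The new AFTER rule in symbolic keeps only genuinely side-effecting dependencies and replaces anything pure with its sources; spelled out, the lambda above does roughly this (a restatement for clarity, not new behavior):

SIDE_EFFECTS = {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END}
new_deps = flatten([(y,) if y.op in SIDE_EFFECTS else y.src for y in x.src[1:]])
x = x.replace(src=(x.src[0],)+tuple(new_deps))  # an AFTER left with a single src then folds to src[0]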
@@ -2,6 +2,7 @@
<html>
<head>
<title>tinygrad viz</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>

@@ -38,24 +39,37 @@
::-webkit-scrollbar-thumb { background: #686977; }
a {
color: #4a90e2;
text-decoration: underline;
cursor: pointer;
}
ul {
padding: 0;
opacity: 0.6;
white-space: nowrap;
cursor: pointer;
}
ul.active {
ul > p {
opacity: 0.6;
}
ul.active > p {
opacity: 1;
}
ul > ul {
display: none;
margin-left: 6px;
}
ul.has-children > p::before {
content:"▸ ";
}
ul.has-children.expanded > p::before {
content:"▾ ";
}
ul.expanded > ul {
display: block;
}
ul.disabled {
ul.disabled > p {
opacity: 0.4;
}
ul.disabled {
pointer-events: none;
}
label {

@@ -137,7 +151,7 @@
.metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * {
margin-top: 12px;
}
.ctx-list > ul > * + * {
ul > * + * {
margin-top: 4px;
}
.graph {
@@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) {
if (parents == null && children == null) return;
const src = [...parents, ...children, d.id];
nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id));
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath");
d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port");
e.stopPropagation();

@@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) {
}).selectAll("text").data(d => {
const ret = [[]];
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
for (const [i, l] of st.split("\n").entries()) {
if (i > 0) ret.push([]);
ret.at(-1).push({ st:l, color });
}
const lines = st.split("\n");
ret.at(-1).push({ st:lines[0], color });
for (let i=1; i<lines.length; i++) ret.push([{ st:lines[i], color }]);
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")

@@ -174,10 +173,16 @@ function tabulate(rows) {
return root;
}

var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity;
var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity, shapeMetadata = new Map();
function focusShape(shape) {
saveToHistory({ shape:focusedShape });
focusedShape = shape?.key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
return document.querySelector(".metadata").replaceChildren(shapeMetadata.get(focusedShape) ?? "");
}

async function renderProfiler() {
displayGraph("profiler");
d3.select(".metadata").node().replaceChildren(focusedShape?.html ?? "");
d3.select(".metadata").node().replaceChildren(shapeMetadata.get(focusedShape) ?? "");
// layout once!
if (data != null) return updateProgress({ start:false });
const profiler = d3.select(".profiler").html("");
@@ -243,15 +248,17 @@ async function renderProfiler() {
}
const html = document.createElement("div");
html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node());
const argsDiv = document.createElement("div"); argsDiv.id = "args"; html.appendChild(document.createElement("br")); html.appendChild(argsDiv);
if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info;
if (shapeRef != null) {
const p = html.appendChild(document.createElement("p"));
p.innerText = "\nView Codegen Rewrite"; p.style.cursor = "pointer";
p.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step);
const a = html.appendChild(document.createElement("a"));
a.innerText = "\nView codegen rewrite";
a.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step);
}
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), html, key, ...shapeRef };
if (key != null) shapeMetadata.set(key, html);
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, ...shapeRef };
if (e.key != null) shapeMap.set(e.key, arg);
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor });
@@ -295,17 +302,30 @@ async function renderProfiler() {
const rows = [["DType", dtype], ["Len", formatUnit(sz)], ["Size", formatUnit(nbytes, "B")], ["Lifetime", formatTime(dur)]];
if (users != null) rows.push(["Users", users.length]);
const info = html.appendChild(tabulate(rows).node());
const arg = {tooltipText:info.outerHTML, key:`${k}-${num}`};
for (let u=0; u<users?.length; u++) {
const p = html.appendChild(document.createElement("p")); p.style.marginTop = "4px";
const { repr, num, mode, shape } = users[u]; p.appendChild(colored(`[${u}] ${repr} ${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`));
const { repr, num, mode, shape } = users[u];
const bufInfo = `${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`
p.appendChild(colored(`[${u}] ${repr} ${bufInfo}`));
const metadata = shape?.tooltipText?.split("\n").at(-1);
if (metadata != null) p.appendChild(document.createElement("span")).innerText = "\n"+metadata;
if (shape != null) {
p.style.cursor = "pointer";
p.onclick = () => focusShape(shape);
const args = shapeMetadata.get(shape.key).querySelector("#args");
const bufArg = d3.create("p").text(`${bufInfo} ${rows[2][1]}`).style("cursor", "pointer").style("margin-top", "4px").on("click", () => {
const device = document.getElementById(k);
if (!isExpanded(device)) device.click();
focusShape(arg);
}).node();
bufArg.dataset.num = num;
let before = null;
for (const c of args.children) { if (+c.dataset.num > num) { before = c; break; } }
args.insertBefore(bufArg, before);
}
}
const arg = {tooltipText:info.outerHTML, html, key:`${k}-${num}`};
shapeMetadata.set(arg.key, html)
shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) });
}
// generic polygon merger
@@ -338,6 +358,7 @@ async function renderProfiler() {
else if (tid === focusedDevice) { track.shapes = track.views[0]; offset += rescaleTrack(track, tid, 1/track.scaleFactor); }
}
data.axes.y = newFocus != null ? { domain:[0, (t=data.tracks.get(newFocus)).peak], range:[t.offsetY+t.height, t.offsetY], fmt:"B" } : null;
toggleCls(document.getElementById(focusedDevice), document.getElementById(newFocus), "expanded");
focusedDevice = newFocus;
return resize();
});
@@ -396,7 +417,7 @@ async function renderProfiler() {
lw += e.label[li].width;
}
}
if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push([p, pcolor]); }
if (focusedShape != null && e.arg?.key === focusedShape) { paths.push([p, pcolor]); }
}
}
// draw axes
@@ -463,15 +484,11 @@ async function renderProfiler() {
}
}

function focusShape(shape) {
focusedShape = shape; render(zoomLevel);
return document.querySelector(".metadata").replaceChildren(shape?.html ?? "");
}
canvas.addEventListener("click", e => {
e.preventDefault();
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.step != null && foundRect?.key == null) { return setCtxWithHistory(foundRect.ctx, foundRect.step); }
if (foundRect?.key != focusedShape?.key) { focusShape(foundRect); }
if (foundRect?.key != focusedShape) { focusShape(foundRect); }
});

canvas.addEventListener("mousemove", e => {
@@ -530,10 +547,10 @@ function codeBlock(st, language, { loc, wrap }={}) {
return ret;
}

function setActive(e) {
if (e == null) return;
e.classList.add("active");
requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" }));
function toggleCls(prev, next, cls, value) {
prev?.classList.remove(cls);
next?.classList.toggle(cls, value ?? true);
requestAnimationFrame(() => next?.scrollIntoView({ behavior: "auto", block: "nearest" }));
}

// ** hljs extra definitions for UOps and float4
@@ -563,31 +580,41 @@ const evtSources = [];
// context: collection of steps
const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false};
function setState(ns) {
const { currentCtx:prevCtx, currentStep:prevStep } = state;
const { ctx:prevCtx, step:prevStep } = select(state.currentCtx, state.currentStep);
Object.assign(state, ns);
// update element styles if needed
document.getElementById(`ctx-${state.currentCtx}`)?.classList.toggle("expanded", state.expandSteps);
if (state.currentCtx !== prevCtx) {
document.getElementById(`ctx-${prevCtx}`)?.classList.remove("active", "expanded");
setActive(document.getElementById(`ctx-${state.currentCtx}`));
}
if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) {
document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active");
setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`));
const { ctx, step } = select(state.currentCtx, state.currentStep);
toggleCls(prevCtx, ctx, "expanded", state.expandSteps);
if (ctx?.id !== prevCtx?.id) toggleCls(prevCtx, ctx, "active");
if (ctx?.id !== prevCtx?.id || step?.id !== prevStep?.id) {
toggleCls(prevStep, step, "active");
// walk the tree back until all parents expanded so that the child is visible
let e = step;
while (e?.parentElement?.id.startsWith("step")) {
e.parentElement.classList.add("expanded");
e = e.parentElement;
}
}
// re-render
main();
}

const getSubrewrites = (ul) => ul.querySelectorAll(":scope > ul");

function saveToHistory(ns) {
// NOTE: browser does a structured clone, passing a mutable object is safe.
history.replaceState(ns, "");
history.pushState(ns, "");
}

// set a new context and keep the old one in browser history
function setCtxWithHistory(newCtx, step=0) {
// NOTE: browser does a structured clone, passing a mutable object is safe.
history.replaceState(state, "");
history.pushState(state, "");
saveToHistory(state);
setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 });
}

window.addEventListener("popstate", (e) => {
if (e.state?.shape != null) return focusShape({ key:e.state?.shape });
if (e.state != null) setState(e.state);
});
@@ -605,15 +632,25 @@ async function main() {
p.onclick = () => {
setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
}
const stack = []; let list = ul;
for (const [j,u] of steps.entries()) {
const inner = ul.appendChild(document.createElement("ul"));
inner.id = `step-${i}-${j}`;
inner.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
inner.style.marginLeft = `${8*u.depth}px`;
inner.onclick = (e) => {
while (stack.length && stack.at(-1).depth >= u.depth) stack.pop();
const list = stack.length > 0 ? stack.at(-1).li : ul;
u.li = list.appendChild(document.createElement("ul"));
u.li.id = `step-${i}-${j}`;
const p = u.li.appendChild(document.createElement("p"));
p.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
p.onclick = (e) => {
e.stopPropagation();
setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
const subrewrites = getSubrewrites(e.currentTarget.parentElement);
if (subrewrites.length) { e.currentTarget.parentElement.classList.toggle("expanded"); }
setState({ currentStep:j, currentCtx:i });
}
stack.push(u);
}
for (const l of ul.querySelectorAll("ul > ul > p")) {
const subrewrites = getSubrewrites(l.parentElement);
if (subrewrites.length > 0) { l.innerText += ` (${subrewrites.length})`; l.parentElement.classList.add("has-children"); }
}
}
return setState({ currentCtx:-1 });
@@ -706,8 +743,9 @@ async function main() {
rewriteList.className = "rewrite-list";
for (let s=0; s<=step.match_count; s++) {
const ul = rewriteList.appendChild(document.createElement("ul"));
ul.innerText = s;
ul.id = `rewrite-${s}`;
const p = ul.appendChild(document.createElement("p"));
p.innerText = s;
ul.onclick = () => setState({ currentRewrite:s });
ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : "";
if (s > 0 && s === currentRewrite) {
@@ -762,22 +800,32 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid

// **** keyboard shortcuts

const select = (ctx, step) => ({ ctx:document.getElementById(`ctx-${ctx}`), step:document.getElementById(`step-${ctx}-${step}`) });
const deselect = (element) => {
const parts = element?.id.split("-").map(Number);
return element?.id.startsWith("ctx") ? { ctx:parts[1], step:null } : element?.id.startsWith("step") ? {ctx:parts[1], step:parts[2]} : {};
}
const isExpanded = (el) => el?.classList.contains("expanded");

document.addEventListener("keydown", (event) => {
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
// up and down change the step or context from the list
const changeStep = expandSteps && ctxs[currentCtx].steps?.length;
const { step, ctx } = select(currentCtx, currentStep);
if (event.key == "ArrowUp") {
event.preventDefault();
if (changeStep) {
return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
let prev = deselect(step.previousElementSibling);
if (prev.step == null && isExpanded(step.parentElement)) prev = deselect(step.parentElement);
return prev.step != null && !isExpanded(step) && setState({ currentRewrite:0, currentStep:prev.step });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false });
}
if (event.key == "ArrowDown") {
event.preventDefault();
if (changeStep) {
const totalUOps = ctxs[currentCtx].steps.length-1;
return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) });
const next = deselect(isExpanded(step) ? step.children[1] : step.nextElementSibling);
return next.step != null && setState({ currentRewrite:0, currentStep:next.step });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false });
}
@@ -787,6 +835,7 @@ document.addEventListener("keydown", (event) => {
if (currentCtx === -1) {
return setState({ currentCtx:0, expandSteps:true });
}
if (expandSteps && getSubrewrites(step).length) return step.children[0].click();
return setState({ expandSteps:!expandSteps });
}
// left and right go through rewrites in a single UOp
@@ -17,10 +17,10 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00"}
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}

# VIZ API
@@ -33,8 +33,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
"query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
if isinstance(k.ret, ProgramSpec):
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src", "depth":0})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm", "depth":0})
for key in k.keys: ref_map[key] = i
ret.append({"name":k.display_name, "steps":steps})
return ret
@@ -71,7 +71,7 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src):
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
@@ -82,6 +82,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"
if u.op is Ops.END:
label += "\n"+' '.join([f"{colored(u.src[i].arg[0], axis_colors[u.src[i].arg[-1]])}({u.src[i].vmax+1})" for i in range(u.arg)])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
@@ -98,9 +100,11 @@ def _reconstruct(a:int):
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)

def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"})
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None,
"diff":None, "upat":None}
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
ignore_indexing = trace.keys[i].display_name.startswith("Schedule") and not (ctx.name in {"kernel split"} or \
any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,)))
yield {"graph":uop_to_json(next_sink, ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)