Merge branch 'master' into multioutput

George Hotz, 2025-10-21 18:16:34 +08:00, committed via GitHub
49 changed files with 496 additions and 624 deletions

View File

@@ -302,4 +302,4 @@ runs:
- name: Install mesa (macOS)
if: inputs.mesa == 'true' && runner.os == 'macOS'
shell: bash
run: brew install sirhcm/tinymesa/tinymesa
run: brew install sirhcm/tinymesa/tinymesa_cpu

View File

@@ -626,11 +626,11 @@ jobs:
- name: benchmark openpilot 0.9.9 dmonitoring
run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx
- name: openpilot compile3 0.10.1 driving_vision
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: openpilot compile3 0.10.1 driving_policy
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.1 dmonitoring
run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx
- name: benchmark MobileNetV2 on DSP
run: |
# generate quantized weights

View File

@@ -520,17 +520,17 @@ generate_mesa() {
LVP_NIR_OPTIONS=$(./extra/mesa/lvp_nir_options.sh $MESA_SRC)
fixup $BASE/mesa.py
patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "f'{brew_prefix()}/lib/libtinymesa_cpu.dylib'"
patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "brew_path('tinymesa_cpu')" "brew_path('tinymesa')"
echo "lvp_nir_options = gzip.decompress(base64.b64decode('$LVP_NIR_OPTIONS'))" >> $BASE/mesa.py
cat <<EOF | sed -i "/import ctypes.*/r /dev/stdin" $BASE/mesa.py
def brew_prefix():
try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip()
except Exception: return ''
def brew_path(nm):
try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib"
except Exception: return 'failed'
EOF
sed -i "/in_dll/s/.*/try: &\nexcept AttributeError: pass/" $BASE/mesa.py
sed -i "/in_dll/s/.*/try: &\nexcept (AttributeError, ValueError): pass/" $BASE/mesa.py
sed -i "s/import ctypes/import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers/" $BASE/mesa.py
sed -i "s/ctypes.CDLL('.\+')/(dll := _try_dlopen_tinymesa_cpu())/" $BASE/mesa.py
echo "def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
echo "def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py
sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/mesa.py
# bitfield bug in clang2py
sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/mesa.py

View File

@@ -121,6 +121,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file, tol):
print("test vs onnx passed")
return timings
def bench(run, inputs):
from extra.bench_log import WallTimeEvent, BenchEvent
for _ in range(10):
with WallTimeEvent(BenchEvent.STEP):
run(**inputs).numpy()
if __name__ == "__main__":
onnx_file = fetch(OPENPILOT_MODEL)
inputs, outputs = compile(onnx_file)
@@ -131,3 +137,5 @@ if __name__ == "__main__":
if not getenv("FLOAT16"):
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
if getenv("BENCHMARK_LOG", ""):
bench(pickle_loaded, inputs)
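
The new bench helper wraps each inference step in a WallTimeEvent, and the call is gated on the BENCHMARK_LOG env var (matching the BENCHMARK_LOG=... CI runs above). A minimal sketch of the timing pattern, with time.sleep standing in for the model call (run from the repo root with PYTHONPATH=.):

import time
from extra.bench_log import WallTimeEvent, BenchEvent
for _ in range(3):
  with WallTimeEvent(BenchEvent.STEP):  # records wall time for the block; logged when BENCHMARK_LOG is set
    time.sleep(0.01)                    # stands in for run(**inputs).numpy()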

View File

@@ -99,6 +99,7 @@ if __name__ == "__main__":
parser.add_argument('--timing', action='store_true', help="Print timing per step")
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()
N = 1
@@ -112,19 +113,22 @@ if __name__ == "__main__":
model = StableDiffusionV2(**params)
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
weights_fn = args.weights_fn
if not weights_fn:
weights_url = args.weights_url if args.weights_url else default_weights_url
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
load_state_dict(model, safe_load(weights_fn), strict=False)
if not args.fakeweights:
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
weights_fn = args.weights_fn
if not weights_fn:
weights_url = args.weights_url if args.weights_url else default_weights_url
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
load_state_dict(model, safe_load(weights_fn), strict=False)
if args.fp16:
for k,v in get_state_dict(model).items():
if k.startswith("model"):
v.replace(v.cast(dtypes.float16).realize())
v.replace(v.cast(dtypes.float16))
Tensor.realize(*get_state_dict(model).values())
c = { "crossattn": model.cond_stage_model(args.prompt) }
uc = { "crossattn": model.cond_stage_model("") }

View File

@@ -263,14 +263,16 @@ if __name__ == "__main__":
parser.add_argument('--timing', action='store_true', help="Print timing per step")
parser.add_argument('--seed', type=int, help="Set the random latent seed")
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()
model = StableDiffusion()
# load in weights
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
if not args.fakeweights:
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
if args.fp16:
for k,v in get_state_dict(model).items():

View File

@@ -2,7 +2,8 @@ from extra.models.resnet import ResNet50
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import Profiling, Timing, getenv
from tinygrad.uop.ops import Ops
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer
from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites
from tinygrad.codegen.late.control_flow import linearize
from tinygrad.uop.spec import type_verify
if __name__ == "__main__":
@@ -39,7 +40,7 @@ if __name__ == "__main__":
with Timing("***** model linearize in "):
uops_line = []
for u in rewritten_uops:
uops_line.append(apply_rewrites(u, rewrites_for_linearizer))
uops_line.append(linearize(u))
with Timing("***** model verify in "):
for u in uops_line: type_verify(u.arg.lst)
print(sum(len(u.arg.lst) for u in uops_line))
for u in uops_line: type_verify(u)
print(sum(len(u) for u in uops_line))

View File

@@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase):
# theoretical is nv_tflops=165, amd_tflops=123
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=65)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=60)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=115, amd_tflops=60)
# theoretical is nv_gbs=1008, amd_gbs=960
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=820, amd_gbs=750)
if __name__ == '__main__':
unittest.main()

View File

@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
for i,u in enumerate(ranges):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1]}, {u}"
@@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
uops = get_program(ast, opts=opt).uops
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
for u in uops:
if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
@@ -214,8 +214,8 @@ class TestLinearizer(unittest.TestCase):
else:
assert u.src[1].op in GroupOp.ALU
assert begin_range < uops.index(u) < end_range
# children of STORE are placed after ENDRANGE
if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src):
# children of END are placed after ENDRANGE
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
assert end_range < uops.index(u)
def test_grouped_dims(self):
@@ -400,7 +400,7 @@ class TestLinearizer(unittest.TestCase):
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)
assert len([u for u in uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1
assert len([u for u in uops if u.op is Ops.IF and u.src[1] == barrier]) == 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

View File

@@ -2602,6 +2602,7 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
@unittest.skipIf(Device.DEFAULT == "AMD" and CI, "remu failure?")
def test_avg_pool3d_failure(self):
with Context(NOOPT=0):
helper_test_op([(1,1,16,16,16)],

View File

@@ -1,7 +1,9 @@
import unittest
from tinygrad import Tensor, nn
from tinygrad.helpers import Context, GlobalCounters, CI, CPU_LVP, getenv
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import Context, GlobalCounters, CI, getenv
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
class TestRangeifyAssign(unittest.TestCase):
def test_assign_permuted(self):
@@ -40,7 +42,7 @@ elif getenv("BIG") > 0:
else:
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
@unittest.skipIf(CPU_LVP, "broken in LVP")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
class TestPcontig(unittest.TestCase):
def test_flash_attention_bw(self):
def fa_bw():
@@ -62,7 +64,7 @@ class TestPcontig(unittest.TestCase):
Tensor.realize(*ret)
return ret
with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2):
with Context(PCONTIG=2, DEBUG=2):
grads = fa_bw()
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")

View File

@@ -3,7 +3,7 @@ import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import run_schedule
from tinygrad.uop.ops import Ops, UOp, UPat
from tinygrad.uop.ops import UOp
from tinygrad.helpers import SPLIT_REDUCEOP
class TestTensorUOp(unittest.TestCase):
@@ -93,7 +93,6 @@ class TestTensorUOp(unittest.TestCase):
out.realize()
self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())
reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, allow_any_len=True, src=(UPat(), UPat((Ops.REDUCE_AXIS, Ops.REDUCE))))))
@unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP")
class TestReduceOp(unittest.TestCase):
def test_no_split_reduce_kernel(self):
@@ -101,23 +100,18 @@ class TestReduceOp(unittest.TestCase):
a = a.sum()
sched = a.schedule()
assert len(sched) == 1
assert reduce_kernel.match(sched[0].ast, {})
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2
for s in sched:
assert reduce_kernel.match(s.ast, {})
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2
for s in sched:
assert reduce_kernel.match(s.ast, {})
if __name__ == "__main__":
unittest.main()

View File

@@ -665,19 +665,6 @@ class TestUOpGraph(unittest.TestCase):
bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_switched_range_order(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp.range(2, 0)
r2 = UOp.range(2, 1)
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
uops = to_uops_list([store])
ranges = [x for x in uops if x.op is Ops.RANGE]
endranges = [x for x in uops if x.op is Ops.ENDRANGE]
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
@track_rewrites()
def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
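
expander_rewrite is just graph_rewrite driven by a PatternMatcher. A toy matcher shows the mechanics (the x*1 fold is written out here only for illustration; the real sym already covers it):

from tinygrad import dtypes
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, graph_rewrite
pm = PatternMatcher([(UPat.var("x")*1, lambda x: x)])  # fold x*1 -> x
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
assert graph_rewrite(a*1, pm) is a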
@@ -845,8 +832,6 @@ class TestIFUOps(unittest.TestCase):
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 2)
def test_expand_ifs_one_gate(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
@@ -863,8 +848,6 @@ class TestIFUOps(unittest.TestCase):
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 2)
# this will be fixed with the merge gated stores bounty
@unittest.expectedFailure
@@ -879,8 +862,6 @@ class TestIFUOps(unittest.TestCase):
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 2)
class TestUOpTags(unittest.TestCase):
def test_inc_by_one(self):

View File

@@ -1,76 +0,0 @@
import unittest, random
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import print_uops, UOp, Ops
from tinygrad.codegen.late.linearize import block_reorder
from tinygrad.renderer.cstyle import OpenCLRenderer
def is_toposorted(lst:list[UOp]):
seen = set()
for u in lst:
if any(p not in seen for p in u.src): return False
seen.add(u)
return True
class TestBlockReorder(unittest.TestCase):
def _test_randomize(self, golden:list[UOp]):
# test random order is always same
for _ in range(50):
# shuffle and form a valid toposort
lst = golden[:]
random.shuffle(lst)
topolst = []
for u in lst:
for p in u.toposort():
if p not in topolst: topolst.append(p)
assert is_toposorted(topolst)
for x,y in zip(golden, this_order:=block_reorder(topolst)):
if x is not y:
print_uops(golden)
print_uops(this_order)
self.assertIs(x, y)
def _test_render(self, golden:list[UOp]):
return OpenCLRenderer().render(golden)
def test_loads(self):
a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0)
b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1)
c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2)
v1 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx0")
v2 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx1")
v1 = v1*27
v2 = v2*4
loads = [
a.index(v1).load(dtype=dtypes.float),
a.index(v1+1).load(dtype=dtypes.float),
a.index(v1+2).load(dtype=dtypes.float),
a.index(v1+3).load(dtype=dtypes.float),
b.index(v2).load(dtype=dtypes.float),
b.index(v2+1).load(dtype=dtypes.float),
b.index(v2+2).load(dtype=dtypes.float),
b.index(v2+3).load(dtype=dtypes.float)]
#random.shuffle(loads)
sink = c.store(sum(loads)).sink()
# determine golden order
golden = block_reorder(list(sink.toposort()))
# render for test
print(self._test_render(golden))
#print_uops(golden)
# assert the loads are in this order
self.assertListEqual([g.src[0].src[1].render() for g in golden if g.op is Ops.LOAD],
['(gidx1*4)', '((gidx1*4)+1)', '((gidx1*4)+2)', '((gidx1*4)+3)',
'(gidx0*27)', '((gidx0*27)+1)', '((gidx0*27)+2)', '((gidx0*27)+3)'])
# assert math is after loads
first_math = [i for i,g in enumerate(golden) if g.op is Ops.ADD and g.dtype == dtypes.float][0]
assert not any(x.op is Ops.LOAD for x in golden[first_math:])
# confirm the sort is stable
self._test_randomize(golden)
if __name__ == '__main__':
unittest.main()

View File

@@ -20,8 +20,8 @@ class TestKernelize(unittest.TestCase):
self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2)
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
# input Tensor and user contiguous kernelize
self.assertIs(a0.uop.base.op, Ops.ASSIGN)
self.assertIs(a.uop.base.op, Ops.ASSIGN)
self.assertIs(a0.uop.base.op, Ops.AFTER)
self.assertIs(a.uop.base.op, Ops.AFTER)
def test_two_reduce_w_add(self):
a = Tensor.ones(16,16).contiguous()
@@ -31,7 +31,7 @@ class TestKernelize(unittest.TestCase):
# NOTE: the +1 is fused with a1, so a1 is not kernelized
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
# the input to the REDUCE_AXIS is an ASSIGN though
self.assertIs(a1.uop.base.src[0].base.op, Ops.ASSIGN)
self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER)
if __name__ == '__main__':
unittest.main()

View File

@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.symbolic import simplify_valid
from tinygrad.helpers import Context
from .test_uop_symbolic import check_uop_against_string
from test.unit.test_uop_symbolic import check_uop_against_string
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (

View File

@@ -148,6 +148,14 @@ class TestViz(BaseTestViz):
a2 = uop_to_json(a)[id(a)]
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
def test_colored_label_multiline(self):
arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
src = [Tensor.empty(1).uop for _ in range(10)]
a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg)
exec_rewrite(a, [PatternMatcher([])])
a2 = next(get_viz_details(0, 0))["graph"][id(a)]
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
def test_inf_loop(self):
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
b = a.replace(op=Ops.CONST)
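
The multiline-label test in this hunk depends on colored wrapping text in ANSI escape codes and ansistrip removing them again. The round trip in isolation:

from tinygrad.helpers import colored, ansistrip
s = colored("x", "green") + "\n" + colored("y", "red")
assert ansistrip(s) == "x\ny"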

View File

@@ -14,10 +14,10 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize
@dataclass
class RewriteStep:
@@ -30,12 +30,6 @@ class RewriteStep:
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
rewrites_for_linearizer = [
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
RewriteStep(pm_finalize, name="Linearizer: Finalize")]
def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]:
# cache with the values of the context vars
return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value)
@@ -101,11 +95,15 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
# return the list (with optional linearizer)
return ret + (rewrites_for_linearizer if linearizer else [])
# this was the linearizer
ret.append(RewriteStep(pm_merge_ends, name="merge ends"))
ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True))
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
# return the list
return ret
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
"""
@@ -119,6 +117,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
Linear program in UOps.
"""
lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst)
lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None))
if __debug__: type_verify(lst)
return lst

View File

@@ -1,4 +1,4 @@
import math
import math, functools, operator
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
from tinygrad.helpers import all_int, dedup, get_contraction
from tinygrad.dtype import dtypes
@@ -87,7 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp):
except ValueError: continue
return s.substitute(subs)
def add_barrier_and_if(buf:UOp, e:UOp):
# TODO: this is not generic
local_ranges = [x for x in e.ended_ranges if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE]
if len(local_ranges) == 0: return None
return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), e.barrier())))
pm_add_gpudims = PatternMatcher([
# add gpudims must be last
(UPat(Ops.SINK, name="s"), add_gpudims),
# add barrier and if
(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.END, name="e"))), add_barrier_and_if),
])
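
add_barrier_and_if folds every GROUP_REDUCE range into one guard, so after the barrier only the threads at index 0 of each local axis continue. The reduce over eq(0) builds a single boolean uop; with two hypothetical stand-in ranges:

import functools, operator
from tinygrad.uop.ops import UOp
r0, r1 = UOp.range(4, 0), UOp.range(8, 1)  # stand-ins for the local ranges
gate = functools.reduce(operator.and_, [x.eq(0) for x in (r0, r1)])
# gate is the single uop (r0==0) & (r1==0), usable as the IF condition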

View File

@@ -0,0 +1,102 @@
import heapq
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
def linearize(u:UOp) -> list[UOp]:
lst = list(u.toposort())
in_this_block = set(lst)
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, int] = {}
# get local children and assign priorities
# NOTE: this requires the lst be locally toposorted
for u in reversed(lst):
in_degree[u] = 0
for s in u.src:
if s in in_this_block:
local_children[s].append(u)
in_degree[u] += 1
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
priority = [0] + [priorities[x] for x in local_children[u]]
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.BARRIER: priority.append(-1500)
# ranges are scheduled as late as possible so anything that can live outside them stays outside
#if u.op is Ops.RANGE: priority = [2000]
# move defines and consts to the top
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
priorities[u] = min(priority)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
# then force them to be toposorted in as close to the ideal order as possible
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
newlst = []
while heap:
newlst.append(u:=heapq.heappop(heap)[1])
for v in local_children[u]:
in_degree[v] -= 1
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
return newlst
class CFGContext:
def __init__(self, sink:UOp):
# there are 3 relationships between ranges:
# nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
# independent, meaning endrange y is not a dependency of endrange x
# everything is nested inside the sink
deps: dict[UOp, set[UOp]] = {}
nesting: dict[UOp, UOp] = {}
for u in sink.toposort():
deps[u] = set().union(*(deps[s] for s in u.src))
if u.op in (Ops.END, Ops.ENDIF, Ops.SINK):
nesting |= {x:u for x in deps[u] if x.op in (Ops.END, Ops.ENDIF) and (u.op is Ops.SINK or u.src[0] in deps[x]) and x not in nesting}
if u.op in (Ops.RANGE, Ops.END, Ops.IF, Ops.ENDIF): deps[u] |= {u}
self.edges: dict[UOp, UOp] = {}
siblings: dict[UOp, list[UOp]] = {}
for k,vv in nesting.items(): siblings.setdefault(vv, []).append(k)
for k,v in siblings.items():
# range/if that have dependencies on other siblings need to run after them
order = sorted(v, key=lambda x: len(deps[x].intersection(v)))
zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[0]] + order, order)
for x,y in zipped:
# TODO: is this check correct?
if y.src[0] not in x.backward_slice_with_self:
self.edges[y.src[0]] = x
pm_add_control_flow = PatternMatcher([
(UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
])
def do_merge_ends(s:UOp):
# NOTE: this can fail
stacked: dict[UOp, list[UOp]] = {}
dangling_ifs = []
for x in s.toposort():
if x.op in {Ops.END, Ops.ENDIF}:
assert x.op is not Ops.END or x.arg == 1, "ends must be single ends for linearizer"
stacked.setdefault(x.src[0], []).append(x)
if x.op is Ops.IF: dangling_ifs.append(x)
dangling_ifs = [x for x in dangling_ifs if x not in stacked]
replaces = {}
for k,v in stacked.items():
if len(v) == 1: continue
rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=v[0].arg)
for x in v: replaces[x] = rep
if not len(replaces) and not len(dangling_ifs): return None
ret = s.substitute(replaces)
if len(dangling_ifs):
assert len(dangling_ifs) == 1, "we only support 1 dangling if"
ret = ret.replace(src=(UOp(Ops.ENDIF, src=(dangling_ifs[0], *ret.src)),))
return ret
pm_merge_ends = PatternMatcher([
# for rendering and linearizing, all ends must end one loop
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
(UPat(Ops.SINK, name="s"), do_merge_ends),
])
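
The linearize entry point at the top of this new control-flow module is Kahn's algorithm with a twist: nodes are first numbered in an "ideal" priority order, then a min-heap pops ready nodes in that order, producing the valid toposort closest to the ideal. The same idea on a toy string DAG:

import heapq
edges = {"a": ["c"], "b": ["c"], "c": ["d"]}  # dependency -> dependents
ideal = {"b": 0, "a": 1, "c": 2, "d": 3}      # priorities computed elsewhere
in_degree = {n: 0 for n in ideal}
for dsts in edges.values():
  for d in dsts: in_degree[d] += 1
heap = [(ideal[n], n) for n, deg in in_degree.items() if deg == 0]
heapq.heapify(heap)
out = []
while heap:
  _, n = heapq.heappop(heap)
  out.append(n)
  for d in edges.get(n, []):
    in_degree[d] -= 1
    if in_degree[d] == 0: heapq.heappush(heap, (ideal[d], d))
print(out)  # ['b', 'a', 'c', 'd']: ideal order, but never ahead of a dependency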

View File

@@ -123,7 +123,7 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines).or_after(name="buf")), UPat.var("vec"))), expand_index),
# GEP after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
@@ -242,11 +242,13 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
devectorize = PatternMatcher([
# CAST after AFTER
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])
pm_render = PatternMatcher([
@@ -266,9 +268,9 @@ pm_render = PatternMatcher([
UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
# gate any stores that aren't gated with ifs
# gate any stores that aren't gated with if/endif pairs
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
lambda store,idx: UOp(Ops.ENDIF, src=(uif:=UOp(Ops.IF, src=(idx.src[2],)), UOp(Ops.STORE, src=store.src[:2]+(uif,)+store.src[2:]))) if \
len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
])
@@ -293,15 +295,17 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
# if we have a range
if len(reduce_range) != 0:
topo = inp.toposort()
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
ended_ranges = flatten([x.src[:x.arg] for x in topo if x.op is Ops.END])
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,))
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
acc.index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
if len(reduce_range) == 0: return ret
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(ends=reduce_range[::-1])).index(UOp.const(dtypes.int, 0)).load()
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN

View File

@@ -87,7 +87,7 @@ expander = PatternMatcher([
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
Ops.VECTORIZE, Ops.IF, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# BARRIERs aren't actually expanded
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
@@ -145,8 +145,7 @@ def fix_group_for_reduce(x:UOp):
reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr]
buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop)
# gate with an if on the store + do the final reduce
buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf))
# do the final reduce (if/barrier are added in gpudims step)
return buf.reduce(*reduce_loop, arg=x.arg)
pm_pre_expander = PatternMatcher([

View File

@@ -1,243 +0,0 @@
from __future__ import annotations
import heapq
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER
# NOTE: any toposort should be valid here; unlike last time, this isn't required, it's just for speed
def block_reorder(lst:list[UOp]) -> list[UOp]:
in_this_block = set(lst)
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, int] = {}
# get local children and assign priorities
# NOTE: this requires the lst be locally toposorted
for u in reversed(lst):
in_degree[u] = 0
for s in u.src:
if s in in_this_block:
local_children[s].append(u)
in_degree[u] += 1
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
priority = [0] + [priorities[x] for x in local_children[u]]
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.BARRIER: priority.append(-1500)
priorities[u] = min(priority)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))}
# then force them to be toposorted in as close to the ideal order as possible
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
newlst = []
while heap:
newlst.append(u:=heapq.heappop(heap)[1])
for v in local_children[u]:
in_degree[v] -= 1
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
return newlst
# ***** basic block *****
def disp(y:UOp) -> str:
if y.op is Ops.IF: return f'IF{id(y)}'
if y.op is Ops.RANGE: return str(y.arg)
return "<NONE>"
@dataclass(frozen=True, eq=False)
class BasicBlock:
lst: tuple[UOp, ...]
ctx: tuple[UOp, ...] = ()
end: UOp|None = None
cnt: int = 0
child_ctx: tuple[UOp, ...]|None = None
def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks")
def __repr__(self):
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\
f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\
f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx
def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize))
# ***** block context *****
@dataclass
class BlockContext:
child_count: dict[UOp, int]
block_ctxs: dict[UOp, tuple[UOp, ...]]
child_ctxs: dict[UOp, tuple[UOp, ...]]
def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u])
@staticmethod
def from_sink(sink:UOp) -> BlockContext:
# get children and all block contexts
ctx = BlockContext({}, {}, {})
for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
this_block_ctx: list[UOp] = []
ctx.child_count[u] = 0
# get children and accumulate the last_ctx
for s in u.src:
if s.op is Ops.SPECIAL: continue
# NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
ctx.child_count[s] += 1
this_block_ctx += ctx.last_ctx(s)
# save the block ctx. SINK never has anything
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
# RANGE/IF add to the next ctx
# STORE/ASSIGN subtract from the next ctx
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
return ctx
# ***** make blocks *****
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
ends_to_add = [z for z in new_ctx if z not in current_ctx]
while len(ends_to_add):
r:UOp = ends_to_add.pop(-1)
new_ctx = tuple([z for z in new_ctx if z is not r])
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
return base_block
def make_block_bottom_up(ctx:BlockContext, x:UOp):
if x.op is Ops.BLOCKSTART:
current_ctx, child_ctx = x.arg
lst = list(x.src)
child_count = 1
else:
current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None)
lst = [x]
# count of times we've seen this block, or a seed for a new block if we can't merge it
unmergable: defaultdict[UOp, int] = defaultdict(int)
blockseeds = defaultdict(list)
# add the srcs of this to the frontier
# NOTE: things may be in here multiple times, that's okay
frontier_nodes = list(flatten(y.src[::-1] for y in lst))
while len(frontier_nodes):
u = frontier_nodes.pop(0)
if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1:
# count is correct
if (newctx:=ctx.block_ctxs[u]) == current_ctx:
# block has same context, merge it, and put the srcs on the frontier
lst.append(u)
frontier_nodes.extend(u.src[::-1])
else:
# block has different context, add it to blockseeds
blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u)
del unmergable[u]
else:
# count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable
unmergable[u] += 1
# add unmergables to sources
srcs = []
for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt
# add blockseeds, with blockends as needed
for (new_ctx, new_child_ctx), v in blockseeds.items():
base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx))
srcs.append(add_blockends(base_block, new_ctx, current_ctx))
lst = lst[::-1]
if BLOCK_REORDER: lst = block_reorder(lst)
bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
# we prevent the source of the SPECIAL from being linearized since it's not part of the kernel
def raise_bottom_up_gate(): raise BottomUpGate()
block_create = PatternMatcher([
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
(UPat(Ops.SPECIAL), raise_bottom_up_gate)
])
# ***** blockend merging ****
def merge_blockends(sink:UOp) -> UOp|None:
# only run on the final BLOCK with the SINK in it
if sink.arg.lst[-1].op is not Ops.SINK: return None
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
blockends_to_arg: dict[UOp, list[UOp]] = {}
for be in sink.toposort():
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
new_forks = {}
for k,v in blockends_to_arg.items():
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
if len(v) > 1:
bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
# NOTE: bb.ctx != u.arg.ctx can cause problems here
for u in v: new_forks[u] = out
if len(new_forks) == 0: return None
return sink.substitute(new_forks)
pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
# ***** block merging ****
def merge_block(x:UOp):
unmergable_blocks, mergable_blocks = [], []
mergable_dict: defaultdict[UOp, int] = defaultdict(int)
for y in x.src:
if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1
elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1
else: unmergable_blocks.append(y)
for k,v in mergable_dict.items():
if v == k.arg.cnt: mergable_blocks.append(k)
else: unmergable_blocks.extend([k]*v)
if len(mergable_blocks) == 0: return None
del mergable_dict
# create the block
arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst)
return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg)
def remove_blockend(x:UOp):
# if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None
if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]):
assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
parent_block = parent_blocks[0]
assert len(parent_blocks) == parent_block.arg.cnt
# NOTE: DEFINE_ACC doesn't have to be handled in any special way
late_ops = list(x.arg.lst)
# NOTE: we have to add a barrier at the start if barrier is used in the range
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
late_ops = [UOp(Ops.BARRIER)] + late_ops
# peephole opt, remove any BARRIERs next to each other
for i in range(len(late_ops)-1):
if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP)
arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)
return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg)
# else the whole context ended by the blockend is already in this block and we can safely turn it into a block
return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt))
block_merge = PatternMatcher([
(UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
(UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])
# ****** finalize ******
def finalize(sink:UOp) -> UOp:
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
# place the early things
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst)))
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])

View File

@@ -4,8 +4,8 @@ from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@@ -64,21 +64,8 @@ class Scheduler:
return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
def _globalizable_rngs(self) -> list[UOp]:
store_rngs = self.ast.src[0].src[2:]
# filter any not in local stores
local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
# filter any not in reduces
# TODO: enable this
"""
reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE]
for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
"""
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else []
# all ranges that end before any STOREs
return [x for x in self.ast.toposort(lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in self.ast.ranges]
def convert_loop_to_global(self):
if not self.opts.has_local: return None
@@ -89,11 +76,11 @@ class Scheduler:
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
def colors(self) -> list[str]:
store_rngs = flatten([x.src[2:] for x in self.ast.src])
globalizable_rngs = self._globalizable_rngs()
ret = []
for x,r in zip(self.axis_types, self.rngs):
if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
elif r not in store_rngs and x == AxisType.LOOP: ret.append("BLACK")
elif r not in globalizable_rngs and x == AxisType.LOOP: ret.append("BLACK")
else: ret.append(axis_colors[x])
return ret
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
@@ -348,7 +335,7 @@ def apply_opts(ctx:Renderer, ast:UOp):
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD):
if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice):
k = hand_coded_optimizations(k)
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
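
_globalizable_rngs now relies on a gated toposort: traversal stops descending wherever the gate predicate fails, so only ranges reachable without crossing a STORE are considered. A tiny illustration of the gate (assuming, as in the code above, that a False gate prunes that node's subtree):

from tinygrad import Tensor
from tinygrad.uop.ops import Ops
u = (Tensor.ones(4) + 1).sum().uop
full = list(u.toposort())
pruned = list(u.toposort(lambda x: x.op is not Ops.REDUCE_AXIS))  # stop at reduces
assert len(pruned) <= len(full)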

View File

@@ -156,7 +156,9 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
if lib in seen_libs: continue
# filter out kernels that use 1000x more compute than the smallest
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
if least_compute_ops*1000 < this_compute_ops: continue
if least_compute_ops*1000 < this_compute_ops:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}")
continue
seen_libs.add(lib)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))

View File

@@ -13,14 +13,16 @@ def flatten_range(r:UOp):
pm_flatten_range = PatternMatcher([
# real ranges only
(UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range),
# END is only on RANGES. TODO: this is copied from symbolic
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
])
def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}])
def simplify_merge_adjacent(u:UOp) -> UOp|None:
reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE]
i = range_start[u.op]
while i < len(u.src)-1:
r0, r1 = u.src[i], u.src[i+1]
i = 0
while i < len(u.ended_ranges)-1:
r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1]
# check same type
if r0.arg[-1] == r1.arg[-1]:
# check if the ranges to merge are in the same reduces
@@ -39,7 +41,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
return u
pm_simplify_ranges = PatternMatcher([
(UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent),
(UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
])
def mark_range_mod(ctx, r:UOp, c:UOp):
@@ -57,7 +59,7 @@ def do_substitute(ctx, x: UOp):
def dont_sub_ranges_for_image(ctx, x:UOp):
if isinstance(x.src[0].dtype, ImageDType):
for s in x.src[1:]: ctx[s] = None
for s in x.src[0].ranges: ctx[s] = None
pm_split_ranges = PatternMatcher([
(UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod),

View File

@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
from tinygrad.helpers import unwrap_class_type
from tinygrad.helpers import unwrap_class_type, suppress_finalizing
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -163,6 +163,7 @@ class Buffer:
return self._trace_num
@property
def nbytes(self): return self.size*self.dtype.itemsize
@suppress_finalizing
def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
def __repr__(self):
return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \

View File

@@ -22,18 +22,18 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
in_degree: dict[UOp, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
if u.op is not Ops.AFTER: continue # anything that's not an AFTER doesn't write a kernel, so we can skip
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src:
if s.op is Ops.ASSIGN:
if s.op is Ops.AFTER:
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
@@ -43,7 +43,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
else:
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
# linearize KERNEL UOps into ScheduleItems in BFS order

View File

@@ -170,9 +170,9 @@ SPEC = ContextVar("SPEC", 0)
# TODO: disable by default due to speed
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0)
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
@dataclass(frozen=True)
class Metadata:
name: str

View File

@@ -36,7 +36,7 @@ class BatchNorm:
self.weight: Tensor|None = Tensor.ones(sz) if affine else None
self.bias: Tensor|None = Tensor.zeros(sz) if affine else None
self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
self.num_batches_tracked = Tensor.zeros(dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False)
if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]:
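
num_batches_tracked changes from shape (1,) to a true scalar, matching the shape of torch's BatchNorm buffer. Tensor.zeros with no shape arguments yields a 0-d tensor:

from tinygrad import Tensor
a = Tensor.zeros(1)      # shape (1,)
b = Tensor.zeros()       # shape (), a scalar
print(a.shape, b.shape)  # (1,) ()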

View File

@@ -28,9 +28,12 @@ class Estimates:
mult_stack: list[sint] = []
dont_count: set[UOp] = set()
if ignore_indexing:
def range_gate(x): return x.op is not Ops.RANGE
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
dont_count = dont_count.union(u.src[0].toposort())
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())
@@ -45,7 +48,7 @@ class Estimates:
mults *= cast(sint, u.src[0].ssimplify())
# SPECIAL are already counted in mults
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults

View File

@@ -1,7 +1,7 @@
from typing import Literal, Callable, cast
import os, math, sys
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.codegen.opt import tc, axis_letters
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
@@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(Ops.RANGE, name="x"),
@@ -144,6 +144,9 @@ class CStyleLanguage(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue
@@ -160,7 +163,7 @@ class CStyleLanguage(Renderer):
# naming
prefix = None
if u.op is Ops.SPECIAL: r[u] = u.arg
elif u.op is Ops.RANGE: r[u] = "ridx"+range_str(u)
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
@@ -170,7 +173,7 @@ class CStyleLanguage(Renderer):
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \

View File

@@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([
f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
(UPat(Ops.END, name="x"), lambda ctx,x:
f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
@@ -167,6 +167,9 @@ class LLVMRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View File

@@ -1,4 +1,4 @@
from typing import Callable, cast
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap
from tinygrad.renderer import Renderer
@@ -169,10 +169,13 @@ class NIRRenderer(Renderer):
def render(self, uops:list[UOp]):
self.prerender(uops)
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
self.r, self.param_idx, ranges = {}, 0, []
self.r: dict[UOp, Any] = {}
self.param_idx, ranges = 0, []
for u in uops:
if u.op == Ops.NOOP or u.op == Ops.INDEX: pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:
if u.arg is not None: self.b.shader.contents.info.name = mesa.char_pointer_cast(u.arg.function_name)
elif u.op == Ops.DEFINE_LOCAL:
@@ -183,7 +186,7 @@ class NIRRenderer(Renderer):
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
elif u.op == Ops.ENDRANGE:
elif u.op == Ops.END:
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
mesa.nir_pop_loop(self.b, None)

View File

@@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([
if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
(UPat(Ops.END, name="x", src=(UPat.var("src0"),), allow_any_len=True), lambda ctx, x, src0: [
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
@@ -179,6 +179,9 @@ class PTXRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue
@@ -216,7 +219,7 @@ class PTXRenderer(Renderer):
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
if prefix: r[u] = ssa(prefix, u, dtype)

View File

@@ -7,13 +7,14 @@
# LONGDOUBLE_SIZE is: 16
#
import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers
def brew_prefix():
try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip()
except Exception: return ''
def brew_path(nm):
try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib"
except Exception: return 'failed'
PATHS_TO_TRY = [
(BASE:=os.getenv('MESA_PATH', f"/usr{'/local/' if helpers.OSX else '/'}lib"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so'),
f'{BASE}/libtinymesa{EXT}',
f'{brew_prefix()}/lib/libtinymesa_cpu.dylib',
brew_path('tinymesa_cpu'),
brew_path('tinymesa'),
]
def _try_dlopen_tinymesa_cpu():
library = ctypes.util.find_library("tinymesa_cpu")
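NOTE: the generated mesa.py now probes a list of candidate paths, falling back to Homebrew-installed copies on macOS. A runnable sketch of that probe loop; brew_path is copied from the diff, the loop body is an assumption about what _try_dlopen_tinymesa_cpu does with PATHS_TO_TRY:

import ctypes, subprocess

def brew_path(nm: str) -> str:
    # ask Homebrew where a keg lives; 'failed' keeps the path list total
    try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib"
    except Exception: return 'failed'

def try_dlopen(paths: list):
    for p in paths:
        try: return ctypes.CDLL(p)
        except OSError: continue    # missing file or wrong arch: try the next
    return None                     # caller raises FileNotFoundError lazily

lib = try_dlopen(['/usr/lib/libtinymesa_cpu.so', brew_path('tinymesa_cpu'),
                  brew_path('tinymesa')])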
@@ -6087,7 +6088,7 @@ struct_nir_op_info._fields_ = [
nir_op_info = struct_nir_op_info
try: nir_op_infos = (struct_nir_op_info * 489).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_op_infos')
except AttributeError: pass
except (AttributeError, ValueError): pass
try:
nir_op_is_selection = _libraries['FIXME_STUB'].nir_op_is_selection
nir_op_is_selection.restype = ctypes.c_bool
@@ -8118,7 +8119,7 @@ c__EA_nir_intrinsic_index_flag = ctypes.c_uint32 # enum
nir_intrinsic_index_flag = c__EA_nir_intrinsic_index_flag
nir_intrinsic_index_flag__enumvalues = c__EA_nir_intrinsic_index_flag__enumvalues
try: nir_intrinsic_index_names = (ctypes.POINTER(ctypes.c_char) * 75).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_index_names')
except AttributeError: pass
except (AttributeError, ValueError): pass
class struct_nir_intrinsic_instr(Structure):
pass
@@ -8242,7 +8243,7 @@ struct_nir_intrinsic_info._fields_ = [
nir_intrinsic_info = struct_nir_intrinsic_info
try: nir_intrinsic_infos = (struct_nir_intrinsic_info * 732).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_infos')
except AttributeError: pass
except (AttributeError, ValueError): pass
try:
nir_intrinsic_src_components = _libraries['libtinymesa_cpu.so'].nir_intrinsic_src_components
nir_intrinsic_src_components.restype = ctypes.c_uint32
@@ -19877,4 +19878,4 @@ __all__ = \
'union_util_format_description_0', 'util_format_colorspace',
'util_format_layout', 'va_list']
lvp_nir_options = gzip.decompress(base64.b64decode('H4sIAAAAAAAAA2NgZGRkYGAAkYxgCsQFsxigwgwQBoxmhCqFq2WEKwIrAEGIkQxoAEMALwCqVsCiGUwLMHA0QPn29nBJkswHANb8YpH4AAAA'))
def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)')
def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)')

View File

@@ -52,11 +52,11 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is Ops.ENDRANGE:
if uop is Ops.END:
loop_ends[idp[0]] = i
i = idp[0]
continue
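NOTE: the emulator implements loops by jumping the instruction pointer: END sends execution back to its RANGE until the trip count runs out. A toy bytecode interpreter in the same spirit; the (op, src, arg) encoding and the exit-target bookkeeping are simplifications, not PythonProgram's actual layout:

def run(uops):
    ul, i, trace = {}, 0, []
    while i < len(uops):
        op, src, arg = uops[i]
        if op == "RANGE":
            ul[i] = ul.get(i, -1) + 1                       # first visit inits to 0
            if ul[i] == arg: ul[i] = -1; i = src; continue  # src = exit target
        elif op == "END":
            i = src                                         # jump back to the RANGE
            continue
        elif op == "WORK":
            trace.append(ul[src])                           # read the loop variable
        i += 1
    return trace

# RANGE(3) at index 0 exits to 3; WORK reads the loop var; END jumps back to 0
assert run([("RANGE", 3, 3), ("WORK", 0, None), ("END", 0, None)]) == [0, 1, 2]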
@@ -72,7 +72,8 @@ class PythonProgram:
if g: _store(m, o+j, v, dtp[1].scalar())
i += 1
continue
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
if uop is Ops.AFTER: ul[i] = inp[0]
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")

View File

@@ -58,7 +58,7 @@ class LLVMCompiler(Compiler):
self.diag_msgs.append(msg)
self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)

View File

@@ -69,7 +69,9 @@ class PTXCompiler(Compiler):
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
class NVPTXCompiler(PTXCompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
def __init__(self, arch:str):
nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint())))
super().__init__(arch, cache_key="nv_ptx")
def compile(self, src:str) -> bytes:
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)

View File

@@ -52,11 +52,11 @@ class IndexingContext:
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.KERNEL}: return None
if x.op is Ops.ASSIGN and x.src[1].op is Ops.KERNEL: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
new_srcs = []
for s in x.src:
new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]

View File

@@ -2,10 +2,9 @@ from typing import cast
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate
from tinygrad.uop.symbolic import symbolic_flat
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
from tinygrad.helpers import DEBUG_RANGEIFY
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -137,6 +136,9 @@ def cleanup_dead_axes(b:UOp):
# move the tag to the expand. NOTE: this expand tag might not survive
return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag)
def gate_substitute(ctx, b:UOp) -> None:
if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()
pm_gate_substitute = PatternMatcher([(UPat(GroupOp.All, name="b"), gate_substitute)], compiled=False)
# if a buffer is being stored just for permutes or something, remove it
# we want to reexpress the indexes of idx2 in terms of the implied b1
def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
@@ -179,11 +181,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# if it makes it here, the bufferize is removed
# this is the ranges replaced
# NOTE: if buf src is a const, we don't replace it
if REAL_SUBSTITUTE:
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST})
else:
replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST])
return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2]))))
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute)
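NOTE: remove_bufferize now substitutes directly, and pm_gate_substitute keeps the walk linear: if none of the substitution keys' ranges occur in a node's range set, BottomUpGate stops the descent there. A toy of that range-gated substitution, with an early return modeling the gate; Node is a stand-in for UOp:

class Node:
    def __init__(self, name, ranges=frozenset(), src=()):
        self.name, self.src = name, tuple(src)
        self.ranges = frozenset(ranges).union(*(s.ranges for s in self.src))

def substitute(n, repl):
    if n in repl: return repl[n]
    # the gate: no key range reaches this subtree, so nothing below can match
    if not any(r in n.ranges for k in repl for r in k.ranges): return n
    return Node(n.name, n.ranges, [substitute(s, repl) for s in n.src])

r_i = Node("range_i", {"i"})
k = Node("idx", src=[r_i])         # substitution key mentions range i
big = Node("big")                  # large range-free subtree
root = Node("add", src=[k, big])
out = substitute(root, {k: Node("idx2", {"j"})})
assert out.src[1] is big           # gated: the range-free subtree is reused as-is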
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
@@ -243,7 +241,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
bufs: set[UOp] = set()
def gate_input(u:UOp):
# TODO: add cache to fix n^2
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
@@ -278,7 +276,8 @@ def bufferize_to_store(x:UOp):
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag)
do_store = assign_target.replace(dtype=sdtype).store(assign_src).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE])
ret = assign_target.src[0].after(do_store)
mops = []
walk = assign_mops
while walk is not assign_mops.base:
@@ -290,8 +289,8 @@ def bufferize_to_store(x:UOp):
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag)
ret = ret.forced_reshape(shape)
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE])
ret = buf.after(do_store).forced_reshape(shape)
# TODO: is this right? what if it's offset
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
@@ -302,9 +301,8 @@ def bufferize_to_store(x:UOp):
tag = x.arg.device
if tag is None: tag = UOp.unique().arg # TODO: hack
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
# store has the other dtype here
# TODO: how is this unified?
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape)
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).end(ends=[x for x in rngs if x.op is Ops.RANGE])
return buf.after(do_store).reshape(shape)
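NOTE: across these hunks BUFFERIZE lowers to one explicit chain: INDEX→STORE, an END that closes the store's ranges, and an AFTER that hands consumers the buffer ordered behind the store. A sketch of that chain's shape; Node and its methods are stand-ins mirroring the UOp.store/end/after helpers this PR adds:

class Node:
    def __init__(self, op, *src, arg=None): self.op, self.src, self.arg = op, src, arg
    def store(self, val):   return Node("STORE", self, val)
    def end(self, ends):    return Node("END", *ends, self, arg=len(ends)) if ends else self
    def after(self, *deps): return Node("AFTER", self, *deps)

buf, rng, val = Node("DEFINE_GLOBAL"), Node("RANGE"), Node("VAL")
do_store = Node("INDEX", buf, rng).store(val).end([rng])
ret = buf.after(do_store)          # consumers read buf, but only after the END
assert (ret.op, ret.src[0].op, ret.src[1].op) == ("AFTER", "DEFINE_GLOBAL", "END")
assert do_store.arg == 1 and do_store.src[0] is rng  # src = (*ends, store)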
pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
@@ -336,12 +334,13 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
ctx.vars[b] = None
return b.src[0]
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
buf = assign.as_buf()
def handle_after(ctx:LocalAddBufferContext, after:UOp):
if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None
buf = after.as_buf()
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
assert buf not in ctx.map
ctx.map[buf] = assign
ctx.map[buf] = after
return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
@@ -351,7 +350,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp):
return ret
def find_bufs(x:UOp):
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.ASSIGN) if s.op is Ops.INDEX]
idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX]
read_from: dict[UOp, Ops] = {}
if any((buf:=idx.as_buf()).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs):
raise RuntimeError(f"cycle detected while indexing {buf}")
@@ -360,7 +359,7 @@ to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="x"), find_bufs),
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
@@ -389,16 +388,9 @@ rangeify_codegen = PatternMatcher([
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
# TODO: this can be moved into codegen
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())),
# TODO: hack for group for reduce
(UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)),
lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
.f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
])
def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
@@ -419,9 +411,8 @@ class Kernel:
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
def split_store(ctx:list[UOp], x:UOp):
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len(x.ranges): return None
if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None
# local kernel rewrite
lctx = LocalAddBufferContext()
@@ -431,16 +422,22 @@ def split_store(ctx:list[UOp], x:UOp):
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
# NOTE: the hack for COPY is here
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \
if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1]
for u in ret.toposort():
# TODO: this can be wrong if there are multiple of these
if u.op in {Ops.COPY, Ops.BUFFER_VIEW}:
ret = u
break
else:
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
return x.as_buf().assign(kernel)
return kernel
split_kernels = PatternMatcher([
(UPat(Ops.STORE, name="x"), split_store),
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
def tag_uop(ctx:list[UOp], x:UOp):
@@ -471,25 +468,6 @@ replace_contiguous = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
def do_sub_recurse(s:UOp):
x,keys,values = s.src[0], s.src[1].src, s.src[2].src
# SUBSTITUTE applied to SUBSTITUTE runs the child SUB on the parents. though this is probably wrong in the generic case
if x.op is Ops.SUBSTITUTE:
sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:])
sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:])
return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v))
# here we actually do the SUBSTITUTE
if x in keys: return values[keys.index(x)]
# we filter any keys where the ranges don't overlap. this keeps the algorithm O(output graph size)
x_ranges = x.ranges
new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)}
# if there's no SUBSTITUTEs left, we can just return x
if len(new_kv) == 0: return x
# then we add SUBSTITUTE to all parents
uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values()))
return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src]))
pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
@@ -504,8 +482,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# TODO: can you substitute and remove costly buffers at the same time?
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
@@ -524,13 +500,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tsink.toposort():
if u.op is not Ops.ASSIGN: continue
if u.op is not Ops.AFTER: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")

View File

@@ -249,9 +249,9 @@ class Tensor(MathTrait):
self.kernelize(*lst)
sink = UOp.sink(*[x.uop for x in (self,)+lst])
# remove all ASSIGNs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN}
_apply_map_to_tensors(remove_assign_map, name="Remove Assigns")
# remove all AFTERs, after scheduling, the tensors are just buffers
remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.AFTER}
_apply_map_to_tensors(remove_assign_map, name="Remove After")
# create the schedule
schedule, var_vals = create_schedule_with_vars(sink)

View File

@@ -12,19 +12,18 @@ class Ops(FastEnum):
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
SENTINEL = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
AFTER = auto()
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
# create buffer
BUFFERIZE = auto()
SUBSTITUTE = auto()
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
# blocks in linearizer (only used there)
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
# movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op
@@ -67,7 +66,7 @@ class Ops(FastEnum):
WHERE = auto(); MULACC = auto() # noqa: E702
# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto() # noqa: E702
@@ -91,7 +90,6 @@ class GroupOp:
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
# BinaryOps that can be flipped
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}
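NOTE: the enum comment above is the whole contract: AFTER forwards src[0] and only adds ordering edges. Since a toposort already visits all sources before a node, listing the deps as src[1:] is enough to schedule them first. A toy illustration (Node stands in for UOp):

class Node:
    def __init__(self, op, *src): self.op, self.src = op, src

def toposort(n, out=None, seen=None):
    out, seen = ([] if out is None else out), (set() if seen is None else seen)
    if id(n) in seen: return out
    seen.add(id(n))
    for s in n.src: toposort(s, out, seen)  # deps first, AFTER srcs included
    out.append(n)
    return out

buf, k = Node("BUFFER"), Node("KERNEL")
a = Node("AFTER", buf, k)
order = [x.op for x in toposort(Node("SINK", a))]
assert order.index("KERNEL") < order.index("AFTER")  # consumers run after k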

View File

@@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def ptrdtype(self) -> PtrDType:
if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType")
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
return self.dtype
# *** uop shape stuff ***
@@ -179,7 +179,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
return None
# some ops init the shape
@@ -190,7 +190,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE: return self.src[0]._shape
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER | Ops.END:
return self.src[0]._shape
# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape
@@ -275,6 +276,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
for s in UOp.sink(*self.src[range_start[self.op]:]).ranges:
if s in ret: del ret[s]
elif self.op is Ops.END:
for s in self.src[self.arg:]: ret.update(s.ranges)
for s in UOp.sink(*self.src[:self.arg]).ranges:
if s in ret: del ret[s]
else:
for s in self.src: ret.update(s.ranges)
return ret
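NOTE: the new END case above tracks which ranges stay open: the first `arg` sources are the ranges being closed and get removed, the remaining sources keep contributing theirs. A small stand-alone model of that bookkeeping (ranges_of is a hypothetical lookup, not UOp.ranges):

def open_ranges(op, src, arg, ranges_of):
    if op == "END":
        ret = set().union(*(ranges_of(s) for s in src[arg:]))
        return ret - set().union(*(ranges_of(s) for s in src[:arg]))  # closed ranges drop out
    return set().union(*(ranges_of(s) for s in src))

# END closes r0; the body also uses r1, which stays open above the END
ranges_of = {"r0": {"r0"}, "body": {"r0", "r1"}}.get
assert open_ranges("END", ["r0", "body"], 1, ranges_of) == {"r1"}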
@@ -284,6 +289,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.RANGE: return {self:None}
return self._ranges
@functools.cached_property
def ended_ranges(self):
match self.op:
case Ops.REDUCE: return self.src[1:]
case Ops.END: return self.src[:self.arg]
case _: raise RuntimeError(f"{self.op} doesn't end ranges")
# *** uop evaluation ***
def simplify(self, tracked=False, full_symbolic=True):
@@ -324,7 +336,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(idx)
def __getitem__(self, *idx): return self.index(*idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
@@ -349,6 +361,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
def end(self, *src:UOp, ends:Sequence[UOp]):
if len(ends) == 0: return self
return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends))
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def alu(self, op, *src:UOp, **kwargs):
@@ -525,6 +541,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT:
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
return self.src[0].device[self.arg]
@@ -538,8 +555,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return self
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
assert self.op is Ops.AFTER, f"must be AFTER {self.op}"
return self.src[0].buf_uop.base
def as_buf(self) -> UOp:
if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
@@ -810,6 +827,8 @@ class UPat(MathTrait):
@staticmethod
def any(*src): return UPatAny(src=src)
def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
def or_after(self, name:str|None=None):
return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))
@staticmethod
@functools.cache
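NOTE: or_after lets an existing pattern keep matching once a define has been wrapped in AFTER, binding the same name either way. A toy matcher showing the two accepted shapes (Node stands in for UOp, the function for the UPat.any above):

class Node:
    def __init__(self, op, *src): self.op, self.src = op, src

def match_or_after(n, op):
    if n.op == op: return n                    # plain match
    if n.op == "AFTER" and n.src[0].op == op:  # wrapped match: same binding
        return n.src[0]
    return None

dl = Node("DEFINE_LOCAL")
assert match_or_after(dl, "DEFINE_LOCAL") is dl
assert match_or_after(Node("AFTER", dl, Node("END")), "DEFINE_LOCAL") is dl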
@@ -836,7 +855,6 @@ class UPat(MathTrait):
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def fuse(self): return self.alu(Ops.FUSE)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def or_broadcasted(self, **kwargs): return UPat.any(self, self.broadcast(**kwargs))
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
@@ -1091,7 +1109,7 @@ class RewriteContext:
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
except BottomUpGate:
# if the bpm matching raised a gate, we are done with this node and don't continue down the srcs
self.replace[n] = new_n
self.replace[n] = unwrap(test_n)
continue
stack.append((n, 1, new_n))
for x in reversed(new_n.src):
@@ -1171,6 +1189,7 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
# TODO: this is only triggering if they are all casts, correct?
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View File

@@ -66,8 +66,8 @@ buffer_spec = PatternMatcher([
])
assign_spec = PatternMatcher([
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@@ -111,6 +111,9 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True)
])
# ***** uop type spec *****
@@ -156,12 +159,16 @@ spec = PatternMatcher([
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x:
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# allow AFTER on buffers
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True),
# **** new style load/store ****
# make sure all index dtypes have been lowered
@@ -171,18 +178,14 @@ spec = PatternMatcher([
# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD on STORE
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
# STORE takes a <bufidx, val, gate?>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
# STORE takes a <bufidx, val, ranges...>
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
@@ -193,7 +196,7 @@ spec = PatternMatcher([
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
(UPat(Ops.END, dtype=dtypes.void), lambda: True),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
@@ -201,9 +204,8 @@ spec = PatternMatcher([
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier?>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
@@ -234,9 +236,6 @@ full_spec = PatternMatcher([
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),
# allow any SUBSTITUTE
(UPat(Ops.SUBSTITUTE), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
# where on index in rhs position is fine
@@ -269,7 +268,7 @@ full_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs
(UPat((Ops.BLOCKSTART, Ops.BLOCK, Ops.BLOCKFINAL, Ops.BLOCKEND, Ops.KERNEL), dtype=dtypes.void), lambda: True),
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
# allow index dtype on a restricted set of UOps
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE,
@@ -283,6 +282,8 @@ full_spec = PatternMatcher([
(UPat(Ops.DEFINE_VAR), lambda: True),
# reshape on STORE
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_uop_spec+spec
# ***** uop helpers *****

View File

@@ -377,6 +377,13 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL/BARRIER/END have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# END is only on RANGES
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([
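NOTE: the new AFTER rule trims dependency lists: sources with side effects (RANGE/IF/STORE/KERNEL/BARRIER/END) are kept, pure sources are replaced by their own sources, and a one-source AFTER folds away. A toy of one flattening step; the real rule reapplies across graph_rewrite passes, and Node is a stand-in for UOp:

SIDE_EFFECTS = {"RANGE", "IF", "STORE", "KERNEL", "BARRIER", "END"}
class Node:
    def __init__(self, op, *src): self.op, self.src = op, src

def flatten_after(n):
    deps = []
    for y in n.src[1:]:
        deps.extend([y] if y.op in SIDE_EFFECTS else list(y.src))  # hoist through pure ops
    if not deps: return n.src[0]    # AFTER with a single src is just src[0]
    return Node("AFTER", n.src[0], *deps)

buf, st = Node("BUFFER"), Node("STORE")
out = flatten_after(Node("AFTER", buf, Node("ADD", st)))  # pure ADD wraps the store
assert [s.op for s in out.src] == ["BUFFER", "STORE"]     # ADD was hoisted away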

View File

@@ -2,6 +2,7 @@
<html>
<head>
<title>tinygrad viz</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>
@@ -38,24 +39,37 @@
::-webkit-scrollbar-thumb { background: #686977; }
a {
color: #4a90e2;
text-decoration: underline;
cursor: pointer;
}
ul {
padding: 0;
opacity: 0.6;
white-space: nowrap;
cursor: pointer;
}
ul.active {
ul > p {
opacity: 0.6;
}
ul.active > p {
opacity: 1;
}
ul > ul {
display: none;
margin-left: 6px;
}
ul.has-children > p::before {
content:"▸ ";
}
ul.has-children.expanded > p::before {
content:"▾ ";
}
ul.expanded > ul {
display: block;
}
ul.disabled {
ul.disabled > p {
opacity: 0.4;
}
ul.disabled {
pointer-events: none;
}
label {
@@ -137,7 +151,7 @@
.metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * {
margin-top: 12px;
}
.ctx-list > ul > * + * {
ul > * + * {
margin-top: 4px;
}
.graph {

View File

@@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) {
if (parents == null && children == null) return;
const src = [...parents, ...children, d.id];
nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id));
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath");
d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port");
e.stopPropagation();
@@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) {
}).selectAll("text").data(d => {
const ret = [[]];
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
for (const [i, l] of st.split("\n").entries()) {
if (i > 0) ret.push([]);
ret.at(-1).push({ st:l, color });
}
const lines = st.split("\n");
ret.at(-1).push({ st:lines[0], color });
for (let i=1; i<lines.length; i++) ret.push([{ st:lines[i], color }]);
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
@@ -174,10 +173,16 @@ function tabulate(rows) {
return root;
}
var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity;
var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity, shapeMetadata = new Map();
function focusShape(shape) {
saveToHistory({ shape:focusedShape });
focusedShape = shape?.key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
return document.querySelector(".metadata").replaceChildren(shapeMetadata.get(focusedShape) ?? "");
}
async function renderProfiler() {
displayGraph("profiler");
d3.select(".metadata").node().replaceChildren(focusedShape?.html ?? "");
d3.select(".metadata").node().replaceChildren(shapeMetadata.get(focusedShape) ?? "");
// layout once!
if (data != null) return updateProgress({ start:false });
const profiler = d3.select(".profiler").html("");
@@ -243,15 +248,17 @@ async function renderProfiler() {
}
const html = document.createElement("div");
html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node());
const argsDiv = document.createElement("div"); argsDiv.id = "args"; html.appendChild(document.createElement("br")); html.appendChild(argsDiv);
if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info;
if (shapeRef != null) {
const p = html.appendChild(document.createElement("p"));
p.innerText = "\nView Codegen Rewrite"; p.style.cursor = "pointer";
p.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step);
const a = html.appendChild(document.createElement("a"));
a.innerText = "\nView codegen rewrite";
a.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step);
}
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), html, key, ...shapeRef };
if (key != null) shapeMetadata.set(key, html);
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, ...shapeRef };
if (e.key != null) shapeMap.set(e.key, arg);
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor });
@@ -295,17 +302,30 @@ async function renderProfiler() {
const rows = [["DType", dtype], ["Len", formatUnit(sz)], ["Size", formatUnit(nbytes, "B")], ["Lifetime", formatTime(dur)]];
if (users != null) rows.push(["Users", users.length]);
const info = html.appendChild(tabulate(rows).node());
const arg = {tooltipText:info.outerHTML, key:`${k}-${num}`};
for (let u=0; u<users?.length; u++) {
const p = html.appendChild(document.createElement("p")); p.style.marginTop = "4px";
const { repr, num, mode, shape } = users[u]; p.appendChild(colored(`[${u}] ${repr} ${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`));
const { repr, num, mode, shape } = users[u];
const bufInfo = `${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`
p.appendChild(colored(`[${u}] ${repr} ${bufInfo}`));
const metadata = shape?.tooltipText?.split("\n").at(-1);
if (metadata != null) p.appendChild(document.createElement("span")).innerText = "\n"+metadata;
if (shape != null) {
p.style.cursor = "pointer";
p.onclick = () => focusShape(shape);
const args = shapeMetadata.get(shape.key).querySelector("#args");
const bufArg = d3.create("p").text(`${bufInfo} ${rows[2][1]}`).style("cursor", "pointer").style("margin-top", "4px").on("click", () => {
const device = document.getElementById(k);
if (!isExpanded(device)) device.click();
focusShape(arg);
}).node();
bufArg.dataset.num = num;
let before = null;
for (const c of args.children) { if (+c.dataset.num > num) { before = c; break; } }
args.insertBefore(bufArg, before);
}
}
const arg = {tooltipText:info.outerHTML, html, key:`${k}-${num}`};
shapeMetadata.set(arg.key, html)
shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) });
}
// generic polygon merger
@@ -338,6 +358,7 @@ async function renderProfiler() {
else if (tid === focusedDevice) { track.shapes = track.views[0]; offset += rescaleTrack(track, tid, 1/track.scaleFactor); }
}
data.axes.y = newFocus != null ? { domain:[0, (t=data.tracks.get(newFocus)).peak], range:[t.offsetY+t.height, t.offsetY], fmt:"B" } : null;
toggleCls(document.getElementById(focusedDevice), document.getElementById(newFocus), "expanded");
focusedDevice = newFocus;
return resize();
});
@@ -396,7 +417,7 @@ async function renderProfiler() {
lw += e.label[li].width;
}
}
if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push([p, pcolor]); }
if (focusedShape != null && e.arg?.key === focusedShape) { paths.push([p, pcolor]); }
}
}
// draw axes
@@ -463,15 +484,11 @@ async function renderProfiler() {
}
}
function focusShape(shape) {
focusedShape = shape; render(zoomLevel);
return document.querySelector(".metadata").replaceChildren(shape?.html ?? "");
}
canvas.addEventListener("click", e => {
e.preventDefault();
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.step != null && foundRect?.key == null) { return setCtxWithHistory(foundRect.ctx, foundRect.step); }
if (foundRect?.key != focusedShape?.key) { focusShape(foundRect); }
if (foundRect?.key != focusedShape) { focusShape(foundRect); }
});
canvas.addEventListener("mousemove", e => {
@@ -530,10 +547,10 @@ function codeBlock(st, language, { loc, wrap }={}) {
return ret;
}
function setActive(e) {
if (e == null) return;
e.classList.add("active");
requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" }));
function toggleCls(prev, next, cls, value) {
prev?.classList.remove(cls);
next?.classList.toggle(cls, value ?? true);
requestAnimationFrame(() => next?.scrollIntoView({ behavior: "auto", block: "nearest" }));
}
// ** hljs extra definitions for UOps and float4
@@ -563,31 +580,41 @@ const evtSources = [];
// context: collection of steps
const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false};
function setState(ns) {
const { currentCtx:prevCtx, currentStep:prevStep } = state;
const { ctx:prevCtx, step:prevStep } = select(state.currentCtx, state.currentStep);
Object.assign(state, ns);
// update element styles if needed
document.getElementById(`ctx-${state.currentCtx}`)?.classList.toggle("expanded", state.expandSteps);
if (state.currentCtx !== prevCtx) {
document.getElementById(`ctx-${prevCtx}`)?.classList.remove("active", "expanded");
setActive(document.getElementById(`ctx-${state.currentCtx}`));
}
if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) {
document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active");
setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`));
const { ctx, step } = select(state.currentCtx, state.currentStep);
toggleCls(prevCtx, ctx, "expanded", state.expandSteps);
if (ctx?.id !== prevCtx?.id) toggleCls(prevCtx, ctx, "active");
if (ctx?.id !== prevCtx?.id || step?.id !== prevStep?.id) {
toggleCls(prevStep, step, "active");
// walk the tree back until all parents are expanded so that the child is visible
let e = step;
while (e?.parentElement?.id.startsWith("step")) {
e.parentElement.classList.add("expanded");
e = e.parentElement;
}
}
// re-render
main();
}
const getSubrewrites = (ul) => ul.querySelectorAll(":scope > ul");
function saveToHistory(ns) {
// NOTE: browser does a structured clone, passing a mutable object is safe.
history.replaceState(ns, "");
history.pushState(ns, "");
}
// set a new context and keep the old one in browser history
function setCtxWithHistory(newCtx, step=0) {
// NOTE: browser does a structured clone, passing a mutable object is safe.
history.replaceState(state, "");
history.pushState(state, "");
saveToHistory(state);
setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 });
}
window.addEventListener("popstate", (e) => {
if (e.state?.shape != null) return focusShape({ key:e.state?.shape });
if (e.state != null) setState(e.state);
});
@@ -605,15 +632,25 @@ async function main() {
p.onclick = () => {
setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 });
}
const stack = []; let list = ul;
for (const [j,u] of steps.entries()) {
const inner = ul.appendChild(document.createElement("ul"));
inner.id = `step-${i}-${j}`;
inner.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
inner.style.marginLeft = `${8*u.depth}px`;
inner.onclick = (e) => {
while (stack.length && stack.at(-1).depth >= u.depth) stack.pop();
const list = stack.length > 0 ? stack.at(-1).li : ul;
u.li = list.appendChild(document.createElement("ul"));
u.li.id = `step-${i}-${j}`;
const p = u.li.appendChild(document.createElement("p"));
p.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : '');
p.onclick = (e) => {
e.stopPropagation();
setState({ currentStep:j, currentCtx:i, currentRewrite:0 });
const subrewrites = getSubrewrites(e.currentTarget.parentElement);
if (subrewrites.length) { e.currentTarget.parentElement.classList.toggle("expanded"); }
setState({ currentStep:j, currentCtx:i });
}
stack.push(u);
}
for (const l of ul.querySelectorAll("ul > ul > p")) {
const subrewrites = getSubrewrites(l.parentElement);
if (subrewrites.length > 0) { l.innerText += ` (${subrewrites.length})`; l.parentElement.classList.add("has-children"); }
}
}
return setState({ currentCtx:-1 });
@@ -706,8 +743,9 @@ async function main() {
rewriteList.className = "rewrite-list";
for (let s=0; s<=step.match_count; s++) {
const ul = rewriteList.appendChild(document.createElement("ul"));
ul.innerText = s;
ul.id = `rewrite-${s}`;
const p = ul.appendChild(document.createElement("p"));
p.innerText = s;
ul.onclick = () => setState({ currentRewrite:s });
ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : "";
if (s > 0 && s === currentRewrite) {
@@ -762,22 +800,32 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid
// **** keyboard shortcuts
const select = (ctx, step) => ({ ctx:document.getElementById(`ctx-${ctx}`), step:document.getElementById(`step-${ctx}-${step}`) });
const deselect = (element) => {
const parts = element?.id.split("-").map(Number);
return element?.id.startsWith("ctx") ? { ctx:parts[1], step:null } : element?.id.startsWith("step") ? {ctx:parts[1], step:parts[2]} : {};
}
const isExpanded = (el) => el?.classList.contains("expanded");
document.addEventListener("keydown", (event) => {
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
// up and down change the step or context from the list
const changeStep = expandSteps && ctxs[currentCtx].steps?.length;
const { step, ctx } = select(currentCtx, currentStep);
if (event.key == "ArrowUp") {
event.preventDefault();
if (changeStep) {
return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) });
let prev = deselect(step.previousElementSibling);
if (prev.step == null && isExpanded(step.parentElement)) prev = deselect(step.parentElement);
return prev.step != null && !isExpanded(step) && setState({ currentRewrite:0, currentStep:prev.step });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false });
}
if (event.key == "ArrowDown") {
event.preventDefault();
if (changeStep) {
const totalUOps = ctxs[currentCtx].steps.length-1;
return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) });
const next = deselect(isExpanded(step) ? step.children[1] : step.nextElementSibling);
return next.step != null && setState({ currentRewrite:0, currentStep:next.step });
}
return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false });
}
@@ -787,6 +835,7 @@ document.addEventListener("keydown", (event) => {
if (currentCtx === -1) {
return setState({ currentCtx:0, expandSteps:true });
}
if (expandSteps && getSubrewrites(step).length) return step.children[0].click();
return setState({ expandSteps:!expandSteps });
}
// left and right go through rewrites in a single UOp

View File

@@ -17,10 +17,10 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00"}
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
# VIZ API
@@ -33,8 +33,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc),
"query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)]
if isinstance(k.ret, ProgramSpec):
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"})
steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src", "depth":0})
steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm", "depth":0})
for key in k.keys: ref_map[key] = i
ret.append({"name":k.display_name, "steps":steps})
return ret
@@ -71,7 +71,7 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src):
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
@@ -82,6 +82,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"
if u.op is Ops.END:
label += "\n"+' '.join([f"{colored(u.src[i].arg[0], axis_colors[u.src[i].arg[-1]])}({u.src[i].vmax+1})" for i in range(u.arg)])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
@@ -98,9 +100,11 @@ def _reconstruct(a:int):
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"})
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None,
"diff":None, "upat":None}
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
ignore_indexing = trace.keys[i].display_name.startswith("Schedule") and not (ctx.name in {"kernel split"} or \
any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,)))
yield {"graph":uop_to_json(next_sink, ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)