From 4e46c6732768126decd08d0abed002c1a0f1840e Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 1 Dec 2024 21:50:47 -0500 Subject: [PATCH 01/23] small helpers cleanups (#7977) less lines for ceildiv and partition, and removed one # noqa: E501 --- tinygrad/helpers.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 4d826cea95..5278696b9c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -41,9 +41,7 @@ def fully_flatten(l): return [l] def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm) def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst -def ceildiv(num, amt): - ret = -(num//-amt) - return ret if not isinstance(ret, float) else int(ret) +def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt def data64(data:Any) -> Tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint def data64_le(data:Any) -> Tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint @@ -52,10 +50,9 @@ def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" return {k:v for d in ds for k,v in d.items()} def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]: - a:List[T] = [] - b:List[T] = [] - for s in itr: (a if fxn(s) else b).append(s) - return a,b + ret:Tuple[List[T], List[T]] = ([], []) + for s in itr: (ret[0] if fxn(s) else ret[1]).append(s) + return ret def unwrap(x:Optional[T]) -> T: assert x is not None return x @@ -268,7 +265,8 @@ def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B") def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv)) -def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501 +def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): + return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) @functools.lru_cache(maxsize=None) def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]): class CStruct(ctypes.Structure): From 6c1efb9a726f5fab0c0f95f190242e31b53fb3f9 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 2 Dec 2024 11:08:24 +0800 Subject: [PATCH 02/23] hotfix: amd gemv was flaky --- test/external/speed_v_theoretical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py index b4be246ae7..0dc2284ed9 100644 --- a/test/external/speed_v_theoretical.py +++ b/test/external/speed_v_theoretical.py @@ -91,7 +91,7 @@ class TestKernelSpeed(unittest.TestCase): def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=130, amd_tflops=70) def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400) - def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=400) + def 
test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400 # TODO: tiny7 is slower than tiny12 def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18) From 254c86d7128c67e59464305944feab28b3270291 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 1 Dec 2024 22:35:21 -0500 Subject: [PATCH 03/23] ruff target-version "py38" -> "py310" (#7978) --- ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 83c85189f5..7912b6f54a 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,6 @@ indent-width = 2 preview = true -target-version = "py38" +target-version = "py310" lint.select = [ "F", # Pyflakes From d53cd92364d8d000ebe71e19d2bb0d2235648bef Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:00:48 +0800 Subject: [PATCH 04/23] fix tests for delete lazy [pr] (#7980) --- test/test_dtype.py | 4 +++- test/test_tiny.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index d5e253ad8a..471a9a669f 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -120,7 +120,9 @@ class TestDType(unittest.TestCase): data = [1., 2., 0., 0.5, -1.5, 5.25] for dt in dtypes: arr = np.asarray(data).astype(dt) - tin = Tensor(arr).numpy() + tensor = Tensor(arr) + if not is_dtype_supported(tensor.dtype): continue + tin = tensor.numpy() tor = torch.as_tensor(arr).detach().numpy() assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}" np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3) diff --git a/test/test_tiny.py b/test/test_tiny.py index cd276f4141..04544d825f 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -46,8 +46,9 @@ class TestTiny(unittest.TestCase): nonlocal cnt cnt += 1 return a+b - fa,fb = Tensor([1.,2,3]), Tensor([4.,5,6]) - for _ in range(3): fxn(fa, fb) + for _ in range(3): + fa,fb = Tensor([1.,2,3]), Tensor([4.,5,6]) + fxn(fa, fb) # function is only called twice self.assertEqual(cnt, 2) From 90e2b2d5773e2494fba38d2235bfc1e3f7e6e434 Mon Sep 17 00:00:00 2001 From: mesozoic-egg <133102390+mesozoic-egg@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:33:16 +0800 Subject: [PATCH 05/23] Remove gated store, put rewrite to uopgraph [pr] (#7975) * update test for gated store * put gated store rewrite to uopgraph, rm from ptx * update test update test update test * remove gated st rewrite in llvm * lint --------- Co-authored-by: Mesozoic Egg Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- test/test_uop_graph.py | 3 ++- test/test_uops.py | 27 ++++++++++++++------------- tinygrad/codegen/uopgraph.py | 3 +++ tinygrad/renderer/cstyle.py | 3 --- tinygrad/renderer/llvmir.py | 4 ---- tinygrad/renderer/ptx.py | 3 +-- 6 files changed, 20 insertions(+), 23 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 85396d7b1a..ea7af9f001 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -645,7 +645,8 @@ class TestLoadStoreFolder(unittest.TestCase): assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1 one_store = [x for x in sink.sparents if x.op is Ops.STORE][0] assert len(one_store.src) == 3 - assert str(one_store.src[2]) == str(gate) # huh, why do i need str here? 
+ _if_node = one_store.src[2] + assert _if_node.op == Ops.IF and _if_node.src[0] == gate def test_simple_store_dont_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) diff --git a/test/test_uops.py b/test/test_uops.py index 85896bf82b..d607f27fd5 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -247,58 +247,59 @@ class TestConstantFolding(unittest.TestCase): assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast" class TestGatedStoreRewrite(unittest.TestCase): - @unittest.expectedFailure def test_tiny_gate_store(self): gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) - idx = gidx0 * UOp.const(dtypes.int, 2) + idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2))) val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) - store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate)) + store = UOp(Ops.STORE, dtypes.void, (idx, val, gate)) uops = to_uops_list([store]) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) endif = next(u for u in uops if u.op is Ops.ENDIF) assert endif.src[0] is if_uop - gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)]) + gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)]) self.assertEqual(len(gated_uops), 1) self.assertIs(gated_uops[-1].op, Ops.STORE) - @unittest.expectedFailure def test_gate_some_stores(self): gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) - idx = gidx0*UOp.const(dtypes.int, 2) + idx = gidx0 * UOp.const(dtypes.int, 2) + idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx)) + idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx)) val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) - stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)] - uops = linearize_uop(stores) + stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)] + uops = to_uops_list(stores) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) endif = next(u for u in uops if u.op is Ops.ENDIF) assert endif.src[0] is if_uop - gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)]) + gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)]) self.assertEqual(len(gated_uops), 1) self.assertIs(gated_uops[-1].op, Ops.STORE) # scaled down version of TestLinearizerDumb.test_unmerged_ifs - @unittest.expectedFailure def test_merge_ifs_alt(self): gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0*UOp.const(dtypes.int, 2) + idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx)) + idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx)) val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) - stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)] - uops = linearize_uop(stores) + stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)] + uops = to_uops_list(stores) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) ifs = [u for u in uops if u.op is Ops.IF] endifs = [u for u in uops if u.op is Ops.ENDIF] 
self.assertEqual(len(ifs), 1) self.assertEqual(len(endifs), 1) - gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])]) + gated_uops = tuple(uops[uops.index(ifs[0])+1:uops.index(endifs[0])]) self.assertEqual(len(gated_uops), 2) for x in gated_uops: self.assertIs(x.op, Ops.STORE) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 610c36e5ad..08828328c2 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -483,6 +483,9 @@ pm_render = PatternMatcher([ # move masks of loads/stores (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))), masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask), + # gate any stores that aren't gated with ifs + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), + lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), ]) # *** uop graph *** diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 26d2818488..42d47424f8 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -57,9 +57,6 @@ extra_pm = PatternMatcher([ # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends? (UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None), - # gate any stores that aren't gated with ifs - (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), - lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends) (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), ]) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 64e645b697..f9f7d59a5f 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -85,10 +85,6 @@ class LLVMRenderer(Renderer): (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))), # rewrite cast to bool to CMPNE 0 (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)), - # *** also in cstyle *** - # gate any stores that aren't gated with ifs - (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), - lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), # rewrite MAX to CMPLT + WHERE (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), ]) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index e2c6861b80..1de0a2878d 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -76,8 +76,7 @@ def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.i string_rewrite = PatternMatcher([ (UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"), (UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"), - (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), lambda ctx, x, bidx, var, pred=None: - f"{f'@{ctx.r[pred]} ' if pred is not None and pred.op is not Ops.IF else ''}st.{mem_type(bidx)}" + \ + (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), 
UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \ f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \ f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"), From 9b0859d71780fef5cf3831e317f74e53f2483229 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:34:42 +0800 Subject: [PATCH 06/23] PYTHON device is okay to use everywhere [pr] (#7981) --- tinygrad/device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 9b515d33f7..98b0aaa821 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -20,7 +20,7 @@ class _Device: @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def __get_canonicalized_item(self, ix:str) -> Compiled: cpn = multiprocessing.current_process().name - assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}" + assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}" x = ix.split(":")[0].upper() ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \ if (cname.lower() == x.lower() + "device")][0](ix) From cbcc1c20eb09a1342f6581cfbb99632bade982a8 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:43:09 +0800 Subject: [PATCH 07/23] second try at block linearize (#7892) * second try at block linearize * weeee, works for lil matmul * it's so beautiful * test tiny passes * fix bugs * combine matching BLOCKENDS * wrapping * test lin failures passes * those failures were fake * flip sort order * fix ptx tests * deal with store better * dumb ptx fix * expect less * reduce lines * reduce lines * less lines and cleaner * no defaultdict * tighter * simpler block_parent_count --- test/external/speed_v_theoretical.py | 2 +- tinygrad/codegen/linearize.py | 228 ++++++++++++++++++--------- 2 files changed, 151 insertions(+), 79 deletions(-) diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py index 0dc2284ed9..88bfc7bfa5 100644 --- a/test/external/speed_v_theoretical.py +++ b/test/external/speed_v_theoretical.py @@ -88,7 +88,7 @@ class TestKernelSpeed(unittest.TestCase): # def test_gemm_1024(self): self._test_matmul(1024, nv_tflops=8, amd_tflops=7) # def test_gemm_2048(self): self._test_matmul(2048, nv_tflops=50, amd_tflops=30) def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70) - def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=130, amd_tflops=70) + def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70) def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400) def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400 diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index 356835b7bc..aab5ad3e2b 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,92 +1,164 @@ -from typing import List, Set, Dict, Tuple -import functools, heapq -from 
tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp -from tinygrad.dtype import dtypes -from tinygrad.helpers import DEBUG +from typing import List, Dict, Tuple +import functools, collections +from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite +from tinygrad.dtype import dtypes, PtrDType +from tinygrad.helpers import dedup, flatten, partition -def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]): - if u in children: return srcs[u] - srcs[u] = {} - children[u] = [] - for x in u.src: - srcs[u].update(get_children_dfs(x, children, srcs, in_degree)) - if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None - children[x].append(u) - in_degree[u] = len(u.src) - return srcs[u] +DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, + Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} + +def disp(y:UOp) -> str: + if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0]) + if y.op is Ops.IF: return f'IF{id(y)}' + if y.op is Ops.RANGE: return str(y.arg[0]) + return "" + +class BasicBlock: + def __init__(self, ctx, lst, end=None): + self.ctx, self.lst, self.end = ctx, lst, end + def __repr__(self): + return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\ + f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst]) + +def append_to_block(ctx, x:UOp): + block_ctxs, children = ctx + new_srcs: List[UOp] = [] + to_append: List[UOp] = [] + new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {} + in_this_block = set(x.arg.lst) + for u in x.src: + if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0: + # if it's a fork or not placed, we don't place it + new_srcs.append(u) + elif (block_ctx:=block_ctxs[u]) == x.arg.ctx: + # if it's the same context, we place the UOp in this block and append the parents to it's srcs + new_srcs += list(u.src) + to_append.append(u) + else: + # otherwise, we create a new block with this UOp + new_blocks.setdefault(block_ctx, []).append(u) + if len(to_append) == 0 and len(new_blocks) == 0: return None + + for rng,lst in new_blocks.items(): + new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, lst)) + lrng = list(rng) + for r in rng[::-1]: + if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART: + lrng.remove(r) + new_block = UOp(Ops.BLOCKEND, src=(new_block,), arg=BasicBlock(lrng[:], [UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))], r)) + new_srcs.append(new_block) + return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, to_append+x.arg.lst)) + +make_basic_blocks = PatternMatcher([ + (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock([], [x]))), + (UPat(Ops.BLOCK, name="x"), append_to_block), +]) + +def block_merge(ctx, x:UOp): + # ctx is children here + if x.op is Ops.BLOCKEND: + # if it's a BLOCKEND, see if we are done with placement. 
if all the children of the range are in here + in_this_block = set(x.arg.lst) + if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0: + # find the parent block that has the BLOCKSTART in the ctx + parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx] + if len(parent_blocks) == 1: + parent_block = parent_blocks[0] + # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if) + early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src) + return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src, + BasicBlock([y for y in x.arg.ctx if y is not x.arg.end], early_ops+parent_block.arg.lst+late_ops)) + assert not len(parent_blocks) + + new_srcs: List[UOp] = [] + to_append: List[UOp] = [] + new_ctx = list(x.arg.ctx[:]) + placed = set() + for u in x.src: + if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)): + # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK + new_ctx += u.arg.ctx + new_srcs += list(u.src) + to_append += u.arg.lst + elif u.op is Ops.BLOCKFORK and len([y for y in x.src if y is u]) == u.arg: # block fork appears # of times in srcs + if u not in placed: + new_srcs += list(u.src) + placed.add(u) + else: + # keep it in srcs + new_srcs.append(u) + if len(to_append) == 0 and len(placed) == 0: return None + return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(dedup(new_ctx), to_append+x.arg.lst, x.arg.end)) + +pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),]) def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}" - # filter nodes that don't link to a sink - # BFS toposort + + @functools.lru_cache(None) + def get_block_ctx(x:UOp) -> Tuple[UOp, ...]: + ret: List[UOp] = [] + for u in x.src: + if u.op in {Ops.RANGE, Ops.IF}: ret.append(u) + # don't flow (fully) through assign and store + elif u.op is Ops.STORE: + # ugh, deal with non-reduce locals. 
probably wrong + if isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.local: + idx_context, store_context = get_block_ctx(u.src[0]), get_block_ctx(u) + ret += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE] + elif u.op is Ops.ASSIGN: + # flow though assign, but remove the ranges used in the assign + assert u.src[0].op is Ops.DEFINE_ACC + ret += [x for x in get_block_ctx(u.src[1]) if x not in u.src[0].src[1:]] + else: + # flow though everything else + ret += get_block_ctx(u) + return tuple(dedup(sorted(ret, key=lambda x: x.tuplize))) + + # get children and all block contexts + block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {} children: Dict[UOp, List[UOp]] = {} - range_srcs: Dict[UOp, Dict[UOp, None]] = {} - in_degree: Dict[UOp, int] = {} - get_children_dfs(sink, children, range_srcs, in_degree) + for u in sink.sparents: + for s in u.src: children.setdefault(s, []).append(u) + this_block_ctx = get_block_ctx(u) + block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + this_block_ctx) if u.op in {Ops.IF, Ops.RANGE} else this_block_ctx - @functools.lru_cache(None) - def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]: - if x.op is Ops.SINK: return set() - return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end])) + # TODO: there's probably a clever way to remove this while loop + while 1: + sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children)) - # scope children impact the toposort and END* insertion - scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP} - range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE} + # add BLOCKFORK (slow!) 
+ block_parent_count = collections.Counter(flatten([x.src for x in sink.sparents if x.op is Ops.BLOCK])) + non_block_parents = flatten([x.src for x in sink.sparents if x.op is not Ops.BLOCK]) + forks = {} + for u,child_count in block_parent_count.items(): + if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents: + forks[u] = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], [u])),), arg=child_count) - # assign priorities - def get_priority(u:UOp): - priority = 0 - # prefer ranges that depend on the least number of independent ranges - if u.op is Ops.RANGE and u.arg[1]: - priority += u.arg[0] - for p in range_phi[u]: - priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])]) - elif u.op is Ops.CONST: - # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC - priority -= 100000000000 - else: - # prefer uops that are loop children - priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss]) - if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty - return priority - priorities:Dict[UOp, int] = {u:get_priority(u) for u in children} + if not len(forks): break + sink = sink.substitute(forks) - # prevent priority inversion - @functools.lru_cache(None) - def fix_priority(u:UOp, lowest_priority): - if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}: - priorities[u] = min(priorities[u], lowest_priority) - if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here) - for x in u.src: fix_priority(x, priorities[u]) - fix_priority(sink, 0) + # combine matching BLOCKENDS + blockends_to_arg: Dict[UOp, List[UOp]] = {} + for be in sink.sparents: + if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be) + new_forks = {} + for k,v in blockends_to_arg.items(): + # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails + if len(v) > 1: + new_blockend = UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)), arg=BasicBlock(dedup(flatten([y.arg.ctx for y in v])), v[0].arg.lst, k)) + out = UOp(Ops.BLOCKFORK, src=(new_blockend,), arg=len(v)) + for u in v: new_forks[u] = out + sink = sink.substitute(new_forks) - # NOTE: the compare should never make it all the way to u - queue:List[Tuple[int, Tuple, UOp]] = [] - def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u)) + # final rewrite to merge all blocks into one + sink = graph_rewrite(sink, pm_block_merge, ctx=children) - for u in children: - if in_degree[u] == 0: push(u) - - scope_end: Dict[UOp, UOp] = {} - _uops: List[UOp] = [] - while queue: - p,_,x = heapq.heappop(queue) - if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg) - if x in scope_children: scope_end[x] = x - if x.op is Ops.DEFINE_ACC: - idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE]) - _uops.insert(idx, x) - else: _uops.append(x) - for u, ss in scope_children.items(): - if x in ss: - ss.remove(x) - if len(ss) == 0: scope_end[u] = x - for u in children[x]: - in_degree[u] -= 1 - if in_degree[u] == 0: push(u) - - # end scopes in toposort order - for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,))) + # there should just be one block left, with a few parents with 0 srcs + assert sink.op is Ops.BLOCK + _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize) + assert all(len(x.src) == 0 and 
x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops) + _uops += sink.arg.lst # sanity checks (NOTE: these can cause things to be skipped in BEAM) if not skip_check: type_verify(_uops) From b797aee720b48900f8222673a5636256c55ec5b5 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 2 Dec 2024 01:45:17 -0500 Subject: [PATCH 08/23] uop global buf number tracking try 2 [pr] (#7912) * uop buffer init small refactor [pr] * add early * this way it doesn't need late * buffer_num * itertools.count * count from 0 * down to 380 --- test/external/speed_v_theoretical.py | 4 ++-- tinygrad/engine/schedule.py | 4 ++-- tinygrad/ops.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py index 88bfc7bfa5..35166a2c82 100644 --- a/test/external/speed_v_theoretical.py +++ b/test/external/speed_v_theoretical.py @@ -90,11 +90,11 @@ class TestKernelSpeed(unittest.TestCase): def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70) def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70) - def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400) + def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400 def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400 # TODO: tiny7 is slower than tiny12 def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18) if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 7a279e030e..6c2f635fd8 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -65,7 +65,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache assert buf.op is not None, f"base must be base itself {buf}" dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base if buf.is_realized: - ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers)) + ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype) buffers[ubuf] = buf.buffer op = None elif buf.op is Ops.ASSIGN: @@ -73,7 +73,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache ctx.assigns.add(ubuf:=target.buf_uop) op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg) else: - ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers)) + ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype) buffers[ubuf] = buf.buffer op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg) cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b67003ff5d..fb8008df60 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -378,8 +378,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop Buffer stuff *** + buffer_num = itertools.count(0) @staticmethod - def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype))) + def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype))) @functools.cached_property def device(self) -> str: match self.op: 
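Aside: the next commit below adds a memoized `toposort` cached property on UOp. A minimal standalone sketch of the same post-order traversal idea, using a hypothetical `Node` class purely for illustration (not tinygrad's actual UOp):

# sketch only: hypothetical Node class standing in for UOp
import functools
from typing import Dict, Tuple

class Node:
  def __init__(self, name:str, src:Tuple["Node", ...]=()): self.name, self.src = name, src
  @functools.cached_property
  def toposort(self) -> Dict["Node", None]:
    nodes: Dict["Node", None] = {}
    # visit sources first, then self, so every node appears after all of its parents
    for parent in self.src: nodes.update(parent.toposort)
    nodes[self] = None
    return nodes

a = Node("a"); b = Node("b", (a,)); c = Node("c", (a, b))
print([n.name for n in c.toposort])  # ['a', 'b', 'c']

Because the property is cached per node, shared parents are only walked once, which is what makes the linearizer's repeated traversals cheap.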
From b09310d8c22da553a90f9da43b9c708e8ef00eb8 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:46:16 +0800 Subject: [PATCH 09/23] add toposort method to uops, faster linearize [pr] (#7982) * add toposort method to uops, faster linearize [pr] * trust the toposort * all toposort * Revert "all toposort" This reverts commit db123adfda823ad262a736f1e1df96fef80b3ca8. --- tinygrad/codegen/linearize.py | 44 ++++++++++++++++++----------------- tinygrad/ops.py | 9 +++++++ 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index aab5ad3e2b..2a727254d8 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,5 +1,5 @@ from typing import List, Dict, Tuple -import functools, collections +import collections from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite from tinygrad.dtype import dtypes, PtrDType from tinygrad.helpers import dedup, flatten, partition @@ -95,33 +95,35 @@ pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), blo def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}" - @functools.lru_cache(None) - def get_block_ctx(x:UOp) -> Tuple[UOp, ...]: - ret: List[UOp] = [] - for u in x.src: - if u.op in {Ops.RANGE, Ops.IF}: ret.append(u) + # get children and all block contexts + temp_block_ctxs: Dict[UOp, List[UOp]] = {} + children: Dict[UOp, List[UOp]] = {} + for u in sink.toposort: + this_block_ctx: List[UOp] = [] + for s in u.src: + # save children + children.setdefault(s, []).append(u) + # compute block ctx + if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s) # don't flow (fully) through assign and store - elif u.op is Ops.STORE: + elif s.op is Ops.STORE: # ugh, deal with non-reduce locals. 
probably wrong - if isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.local: - idx_context, store_context = get_block_ctx(u.src[0]), get_block_ctx(u) - ret += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE] - elif u.op is Ops.ASSIGN: + if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local: + idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s] + this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE] + elif s.op is Ops.ASSIGN: # flow though assign, but remove the ranges used in the assign - assert u.src[0].op is Ops.DEFINE_ACC - ret += [x for x in get_block_ctx(u.src[1]) if x not in u.src[0].src[1:]] + assert s.src[0].op is Ops.DEFINE_ACC + this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]] else: # flow though everything else - ret += get_block_ctx(u) - return tuple(dedup(sorted(ret, key=lambda x: x.tuplize))) + this_block_ctx += temp_block_ctxs[s] + temp_block_ctxs[u] = dedup(sorted(this_block_ctx, key=lambda x: x.tuplize)) - # get children and all block contexts + # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {} - children: Dict[UOp, List[UOp]] = {} - for u in sink.sparents: - for s in u.src: children.setdefault(s, []).append(u) - this_block_ctx = get_block_ctx(u) - block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + this_block_ctx) if u.op in {Ops.IF, Ops.RANGE} else this_block_ctx + for u in sink.toposort: + block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u]) # TODO: there's probably a clever way to remove this while loop while 1: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fb8008df60..1ae252a645 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -253,6 +253,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property # parents with self def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} + # TODO: replace usage of sparents with this + @functools.cached_property + def toposort(self) -> Dict[UOp, None]: + nodes: Dict[UOp, None] = {} + # NOTE: this is a lot faster than the comprehension in parents + for parent in self.src: nodes.update(parent.toposort) + nodes[self] = None + return nodes + @functools.cached_property def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src)) From f17af70d17e088fd789995a5c0eadf85d1a1f404 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:00:30 +0800 Subject: [PATCH 10/23] replace all sparents with toposort (#7983) --- test/test_linearizer.py | 8 ++++---- test/test_linearizer_dumb.py | 2 +- test/test_linearizer_failures.py | 2 +- test/test_schedule.py | 2 +- test/test_uop_graph.py | 24 ++++++++++++------------ tinygrad/codegen/linearize.py | 6 +++--- tinygrad/codegen/lowerer.py | 2 +- tinygrad/codegen/uopgraph.py | 4 ++-- tinygrad/ops.py | 19 ++++++++----------- tinygrad/renderer/__init__.py | 2 +- tinygrad/renderer/cstyle.py | 2 +- tinygrad/shape/shapetracker.py | 6 +++--- tinygrad/viz/serve.py | 4 ++-- 13 files changed, 40 insertions(+), 43 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index e0dd59f734..10b1a6e294 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -96,7 +96,7 @@ class TestLinearizer(unittest.TestCase): lin = 
helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0] stores = [u for u in lin.uops if u.op is Ops.STORE] - mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL] for u in stores])) + mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores])) assert len(mutable_bufs) == len(stores) == 2 assert [u.arg for u in mutable_bufs] == [0, 1] @@ -988,10 +988,10 @@ class TestLinearizer(unittest.TestCase): # the first store is to lds and can be upcasted assert stores[0].src[-1].dtype == dtypes.float.vec(4) - assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].sparents) + assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort) # the second store is to gds with no upcasts assert stores[1].src[-1].dtype == dtypes.float - assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].sparents) + assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort) def test_zero_fold(self): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() @@ -1155,7 +1155,7 @@ class TestLinearizer(unittest.TestCase): def test_grouped_dims(self): def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes): idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims) - loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is Ops.SPECIAL] for x in idxs])) + loop_idxs = dedup(flatten([[y for y in x.toposort if y.op is Ops.SPECIAL] for x in idxs])) loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0]) sizes = [x.arg[1] for x in loop_idxs] assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}" diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index d9e7a0e058..f8589ff03b 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -95,7 +95,7 @@ class TestLinearizerDumb(unittest.TestCase): print(prg.src) if_uops = [u for u in k.uops if u.op is Ops.IF] self.assertIn(len(if_uops), {1,2,3}) - conditions = if_uops[0].src[0].sparents + conditions = if_uops[0].src[0].toposort self.assertLessEqual(len(conditions), 9) # this was a bug in embedding, someday we should fold this anyway diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 6dc6e4a97b..80ca7fd6e8 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1045,7 +1045,7 @@ class TestLinearizerFailures(unittest.TestCase): ifs = [u for u in k.uops if u.op is Ops.IF] self.assertEqual(len(ifs), 3) #for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) - self.assertLessEqual(len(ifs[0].src[0].sparents), 17) + self.assertLessEqual(len(ifs[0].src[0].toposort), 17) def test_failure_45(self): ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( diff --git a/test/test_schedule.py b/test/test_schedule.py index d419eb6f11..6dd758171d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1673,7 +1673,7 @@ class TestIndexing(unittest.TestCase): @track_rewrites(named=True) def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right) -def swizzle_cnt(u:UOp) -> int: return len([x for x in u.sparents if x.op is Ops.VIEW and len(x.src) != 0]) +def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0]) class TestSwizzle(unittest.TestCase): def test_swizzle_simple(self): diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index ea7af9f001..f0b04db446 100644 --- 
a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -60,7 +60,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase): new_sink = full_graph_rewrite(lower_sink) et = time.perf_counter() - st UOp.__init__ = old_init - print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.sparents)} -> {len(new_sink.sparents)}, creating {cnt[0]} uops") + print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort)} -> {len(new_sink.toposort)}, creating {cnt[0]} uops") class TestGraphRewriteConst(unittest.TestCase): def test_gep_const(self): @@ -106,7 +106,7 @@ class TestGraphRewrite(unittest.TestCase): a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11))) a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11))) sink = a1.sink(a2) - define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR] + define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort if x.op is Ops.DEFINE_VAR] self.assertEqual(len(define_vars), 1) def test_simple(self): @@ -187,7 +187,7 @@ class TestGraphRewrite(unittest.TestCase): print(sink.render()) self.assertEqual(sink.op, Ops.ADD) self.assertEqual(sink.src[1].op, Ops.CONST) - self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1) + self.assertEqual(len([x for x in sink.toposort if x.op is Ops.CONST]), 1) class TestUOpGraph(unittest.TestCase): def test_add_constant_fold(self): @@ -600,14 +600,14 @@ class TestLoadStoreFolder(unittest.TestCase): sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1 + assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1 def test_two_load_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2 + assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 2 def test_simple_load_fold_gated(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) @@ -615,8 +615,8 @@ class TestLoadStoreFolder(unittest.TestCase): load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1 - single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0] + assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1 + single_load = [x for x in sink.toposort if x.op is Ops.LOAD][0] self.assertEqual(single_load.src[1].op, Ops.VECTORIZE) def test_simple_load_dont_fold_different_gated(self): @@ -627,14 +627,14 @@ class TestLoadStoreFolder(unittest.TestCase): UOp.const(dtypes.float, 0))) for i in range(4)] sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink.sink()) - assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3 + assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 3 def test_simple_store_fold(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = 
float4_rewrite(sink) - assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1 + assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1 def test_simple_store_fold_gate(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr()) @@ -642,8 +642,8 @@ class TestLoadStoreFolder(unittest.TestCase): load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) - assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1 - one_store = [x for x in sink.sparents if x.op is Ops.STORE][0] + assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1 + one_store = [x for x in sink.toposort if x.op is Ops.STORE][0] assert len(one_store.src) == 3 _if_node = one_store.src[2] assert _if_node.op == Ops.IF and _if_node.src[0] == gate @@ -656,7 +656,7 @@ class TestLoadStoreFolder(unittest.TestCase): UOp.const(dtypes.float, i))) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) - assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3 + assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 3 class TestIFUOps(unittest.TestCase): def test_create_ifs(self): diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index 2a727254d8..e4d65b2148 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -130,8 +130,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children)) # add BLOCKFORK (slow!) - block_parent_count = collections.Counter(flatten([x.src for x in sink.sparents if x.op is Ops.BLOCK])) - non_block_parents = flatten([x.src for x in sink.sparents if x.op is not Ops.BLOCK]) + block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK])) + non_block_parents = flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]) forks = {} for u,child_count in block_parent_count.items(): if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents: @@ -142,7 +142,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # combine matching BLOCKENDS blockends_to_arg: Dict[UOp, List[UOp]] = {} - for be in sink.sparents: + for be in sink.toposort: if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be) new_forks = {} for k,v in blockends_to_arg.items(): diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index dd77fff257..b9737f9ec0 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -55,7 +55,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext: full_shape = ast.full_shape first_upcasted = len(full_shape)-ki.upcasted # if there's no reduce, this is first_upcasted. 
assumes reduces are at the end - first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS)) + first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS)) local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL] # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 08828328c2..471eda0478 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -221,7 +221,7 @@ def no_vectorized_wmma(wmma:UOp): return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex)) def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): - reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents) + reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort) if len(reduce_unparented) == 0: return None new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented)) ret = new_acc.assign(new_acc.alu(alu.op, ret)) @@ -447,7 +447,7 @@ devectorize = PatternMatcher([ ]) def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]: - if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None + if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None # remove the gate from the index return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 1ae252a645..e118cdc0ac 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -250,10 +250,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg @functools.cached_property def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}} - @functools.cached_property # parents with self - def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} - # TODO: replace usage of sparents with this @functools.cached_property def toposort(self) -> Dict[UOp, None]: nodes: Dict[UOp, None] = {} @@ -422,12 +419,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def val(self) -> int: return self.unbind()[1] def vars(self) -> Set[UOp]: - bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR]) + bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR]) bound_var_base = set(x.src[0] for x in bound_vars) - all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR]) + all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR]) return bound_vars.union(set([x for x in all_vars if x not in bound_var_base])) def variables(self) -> List[Variable]: - st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer] + st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer] return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg) # *** uop symbolic stuff *** @@ -484,7 +481,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): 
@functools.cached_property def _sym_fxn(self): sself = self.simplify() - varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR) + varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR) # TODO: sanitize varnames, or don't use naked eval while staying fast return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used @@ -542,10 +539,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: if ignore_indexing: for u in uops: if u.op in {Ops.LOAD, Ops.STORE}: - dont_count = dont_count.union(u.src[0].sparents) - if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents) + dont_count = dont_count.union(u.src[0].toposort) + if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort) elif u.op is Ops.IF: - dont_count = dont_count.union(u.src[0].sparents) + dont_count = dont_count.union(u.src[0].toposort) for u in uops: if u.op is Ops.RANGE: mult_stack.append(mults) @@ -1056,7 +1053,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) # try checking the whole clause - if expr in uop.sparents: + if expr in uop.toposort: candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))]) for candidate in candidates: diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 0a1ac9a481..d2cbdf4a95 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -44,7 +44,7 @@ class ProgramSpec: for u in self.uops: if u.op is Ops.DEFINE_VAR: self.vars.append(u) if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg) - if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL]) + if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL]) if u.op is Ops.SPECIAL: # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this if u.arg[0][0] == 'i': self.local_size = None diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 42d47424f8..82fb9faee0 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -127,7 +127,7 @@ class CStyleLanguage(Renderer): # mark buffers that we store to writable if u.op is Ops.STORE: - for up in u.src[0].sparents: + for up in u.src[0].toposort: if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True)) # naming diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 5204b27f1f..90f89006f6 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -81,17 +81,17 @@ class ShapeTracker: if c.op is Ops.RANGE: ret[c.arg[0]] = 1 if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg - used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE] + used_ranges = [x.arg[0] for x in idx.toposort if x.op is Ops.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: - for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None + for masked_axis in [x.arg[0] for x in 
valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None return tuple(ret) def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() - return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE] + return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 2d93b4ced4..be968cbac9 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -60,7 +60,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} - for u in x.sparents: + for u in x.toposort: if u.op is Ops.CONST: continue label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" for idx,x in enumerate(u.src): @@ -92,7 +92,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) # sanity check if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") # update ret data - g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST]) + g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST]) g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) g.graphs.append(sink:=new_sink) return g From 275951b7302f2ad423f886daa25583fd8e2063c4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:59:31 +0800 Subject: [PATCH 11/23] clean up a few parents -> toposort [pr] (#7984) * clean up a few parents -> toposort [pr] * rename to old_parents + sched tests * a few more * that one * second to last * final --- test/test_conv_shapetracker.py | 4 ++-- test/test_dtype_alu.py | 2 +- test/test_schedule.py | 6 +++--- test/test_search.py | 2 +- test/test_uop_graph.py | 6 +++--- test/test_winograd.py | 2 +- tinygrad/codegen/kernel.py | 9 +++++---- tinygrad/codegen/lowerer.py | 2 +- tinygrad/engine/schedule.py | 2 +- tinygrad/ops.py | 4 +--- tinygrad/renderer/ptx.py | 2 +- 11 files changed, 20 insertions(+), 21 deletions(-) diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index 4333861a38..36e9956d1e 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -17,7 +17,7 @@ class TestConvShapetracker(unittest.TestCase): # run it again to get the kernels sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK] assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}" - for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]: + for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]: assert len(st.views) == 1 def test_conv_2x2_backward_one_view(self): @@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase): conv(X).mean().backward() si = X.grad.schedule()[-1] print(si) - ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0] + ldb = [x for x in si.ast.toposort if x.op is Ops.LOAD][0] st: 
ShapeTracker = ldb.st_arg.simplify() # NOTE: st.real_size() is broken print(si.inputs[0].size) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 0f23e71deb..3d46e26dbb 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op): np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) else: np.testing.assert_equal(tensor_value, numpy_value) if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends - op = [x for x in ast.parents if x.op in GroupOp.Unary][0] + op = [x for x in ast.toposort if x.op in GroupOp.Unary][0] assert op.dtype == dtype def universal_test_cast(a, in_dtype, dtype): diff --git a/test/test_schedule.py b/test/test_schedule.py index 6dd758171d..1879e8bc26 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -199,7 +199,7 @@ class TestSchedule(unittest.TestCase): r1 = (x - r0).sum(axis=0).div(2) out = r0 + r1 schedule = check_schedule(out, 2) - reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS] + reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS] assert len(reduceops) == 2 def test_cache_reduce_multiple_children(self): @@ -210,7 +210,7 @@ class TestSchedule(unittest.TestCase): out0 = r0 + y out1 = r1 + y schedule = check_schedule([out0, out1], 4) - reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS] + reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS] assert len(reduceops) == 2 def test_fold_double_unary(self): @@ -1755,7 +1755,7 @@ class TestSwizzle(unittest.TestCase): # EXPAND is rewritten self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape)) # and pushed to the LOAD - new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st) + new_load_st = unwrap([x for x in ret.toposort if x.op is Ops.VIEW][0].st) self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape)) self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27)) diff --git a/test/test_search.py b/test/test_search.py index a9e6d4e6f0..37747b2106 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -18,7 +18,7 @@ class TestTimeLinearizer(unittest.TestCase): def test_reasonable_time(self): si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0] out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate() - memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD} + memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD} rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))] tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True) assert tm > 0 and tm != float('inf') diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index f0b04db446..aa0e9aa26e 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -672,7 +672,7 @@ class TestIFUOps(unittest.TestCase): store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf)) sink = UOp(Ops.SINK, dtypes.void, (store,)) sink = full_graph_rewrite(sink) - if_uops = [u for u in sink.parents if u.op is Ops.IF] + if_uops = [u for u in sink.toposort if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) for st in sink.src: @@ -690,7 +690,7 @@ 
class TestIFUOps(unittest.TestCase): stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(stores)) sink = full_graph_rewrite(sink) - if_uops = [u for u in sink.parents if u.op is Ops.IF] + if_uops = [u for u in sink.toposort if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) for st in sink.src: @@ -706,7 +706,7 @@ class TestIFUOps(unittest.TestCase): stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(stores)) sink = full_graph_rewrite(sink) - if_uops = [u for u in sink.parents if u.op is Ops.IF] + if_uops = [u for u in sink.toposort if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) for st in sink.src: diff --git a/test/test_winograd.py b/test/test_winograd.py index c1a072b602..0acbee9de5 100644 --- a/test/test_winograd.py +++ b/test/test_winograd.py @@ -24,7 +24,7 @@ class TestWinograd(unittest.TestCase): for i,s in enumerate(sched): if s.ast.op is not Ops.SINK: continue - ops = s.ast.parents + ops = s.ast.toposort with Timing(f"linearize {i} with {len(ops):4d} ops: "): l = Kernel(s.ast) l.hand_coded_optimizations() diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 1d069792dd..d048b4f168 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -68,10 +68,11 @@ class Kernel: self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS]) self.vars: List[Variable] = self.ast.variables() - self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer] + # NOTE: this requires a specific order with the [::-1], this is likely a bug + self.bufs: List[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1] # get earlybufs, before any reduceops - earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer] + earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer] self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0 # NOTE: full_shape can be wrong if there's a tree of reduces @@ -597,7 +598,7 @@ class Kernel: @functools.cached_property def name(self) -> str: # kernel name (before late upcast) - kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E") + kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E") suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())]) name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix @@ -712,7 +713,7 @@ class Kernel: # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes # TODO: these max and min don't work on symbolic, and results are very wrong. 
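    # rough illustration (an assumed example, not taken from the patch): with two LOADs of DEFINE_GLOBAL 1
    # touching 1024 and 4096 float32 elements and one STORE to DEFINE_GLOBAL 0 touching 1024 elements,
    # the groups below are (LOAD, 1) and (STORE, 0), giving an estimate of max(1024, 4096)*4 + 1024*4 bytes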
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group) - for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL], + for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL], key=lambda x: (x.op, x.src[0].arg))) return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes, global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index b9737f9ec0..981aad7eeb 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -56,7 +56,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext: first_upcasted = len(full_shape)-ki.upcasted # if there's no reduce, this is first_upcasted. assumes reduces are at the end first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS)) - local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL] + local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL] # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)]) global_dims = first_reduce-ki.local_dims diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6c2f635fd8..397b07c801 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -126,7 +126,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" - assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" + assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg) # push VIEW to stores diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e118cdc0ac..87e25f4b2a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -248,8 +248,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest() def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg - @functools.cached_property - def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}} @functools.cached_property def toposort(self) -> Dict[UOp, None]: @@ -1068,7 +1066,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: def _valid_priority(v: UOp, valids:List[UOp]): # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified - try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids) + try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids) except ValueError: return 0 def simplify_valid(valid:UOp) -> 
Optional[UOp]: diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 1de0a2878d..2e91608b25 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -54,7 +54,7 @@ ptx_matcher = PatternMatcher([ (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), ]) -def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global' +def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global' def render_wmma(ctx: "PTXRenderer", x: UOp): assert ctx.wmma_r, "registry values for wmma must be populated" From 61b2cac50735c138bdb534a1ec83c64f67e1929b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:48:39 +0800 Subject: [PATCH 12/23] basicblock is dataclass (#7985) * basicblock is dataclass [pr] * tiny cleanups --- tinygrad/codegen/linearize.py | 36 ++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index e4d65b2148..98b446596d 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -1,5 +1,6 @@ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Optional import collections +from dataclasses import dataclass from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite from tinygrad.dtype import dtypes, PtrDType from tinygrad.helpers import dedup, flatten, partition @@ -13,9 +14,11 @@ def disp(y:UOp) -> str: if y.op is Ops.RANGE: return str(y.arg[0]) return "" +@dataclass(frozen=True) class BasicBlock: - def __init__(self, ctx, lst, end=None): - self.ctx, self.lst, self.end = ctx, lst, end + ctx: Tuple[UOp, ...] + lst: Tuple[UOp, ...] 
+ end: Optional[UOp] = None def __repr__(self): return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\ f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst]) @@ -40,17 +43,18 @@ def append_to_block(ctx, x:UOp): if len(to_append) == 0 and len(new_blocks) == 0: return None for rng,lst in new_blocks.items(): - new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, lst)) + new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst))) lrng = list(rng) for r in rng[::-1]: if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART: lrng.remove(r) - new_block = UOp(Ops.BLOCKEND, src=(new_block,), arg=BasicBlock(lrng[:], [UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))], r)) + new_block = UOp(Ops.BLOCKEND, src=(new_block,), + arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r)) new_srcs.append(new_block) - return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, to_append+x.arg.lst)) + return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst)) make_basic_blocks = PatternMatcher([ - (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock([], [x]))), + (UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))), (UPat(Ops.BLOCK, name="x"), append_to_block), ]) @@ -67,12 +71,12 @@ def block_merge(ctx, x:UOp): # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if) early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src) return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src, - BasicBlock([y for y in x.arg.ctx if y is not x.arg.end], early_ops+parent_block.arg.lst+late_ops)) + BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops))) assert not len(parent_blocks) new_srcs: List[UOp] = [] to_append: List[UOp] = [] - new_ctx = list(x.arg.ctx[:]) + new_ctx = x.arg.ctx placed = set() for u in x.src: if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)): @@ -88,7 +92,7 @@ def block_merge(ctx, x:UOp): # keep it in srcs new_srcs.append(u) if len(to_append) == 0 and len(placed) == 0: return None - return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(dedup(new_ctx), to_append+x.arg.lst, x.arg.end)) + return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(dedup(new_ctx)), tuple(to_append)+x.arg.lst, x.arg.end)) pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),]) @@ -131,11 +135,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # add BLOCKFORK (slow!) 
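  # illustration (not taken from the patch): a UOp is a fork candidate only when two or more BLOCKs list
  # it as a source and no non-BLOCK op does, e.g. collections.Counter([a, b, a]) counts {a: 2, b: 1}, so
  # only `a` qualifies here (provided a.op is not in DONT_PLACE_IN_BLOCK)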
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK])) - non_block_parents = flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]) - forks = {} - for u,child_count in block_parent_count.items(): - if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents: - forks[u] = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], [u])),), arg=child_count) + non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK])) + forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count) + for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents} if not len(forks): break sink = sink.substitute(forks) @@ -148,8 +150,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: for k,v in blockends_to_arg.items(): # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails if len(v) > 1: - new_blockend = UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)), arg=BasicBlock(dedup(flatten([y.arg.ctx for y in v])), v[0].arg.lst, k)) - out = UOp(Ops.BLOCKFORK, src=(new_blockend,), arg=len(v)) + out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)), + arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v)) for u in v: new_forks[u] = out sink = sink.substitute(new_forks) From 1ea09257442a46e13ee8213d54d8008f05fcf529 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Mon, 2 Dec 2024 09:59:05 +0100 Subject: [PATCH 13/23] Support packed types in smem in webgpu --- test/test_uops.py | 10 ++++++++++ tinygrad/renderer/wgsl.py | 8 ++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index d607f27fd5..effcc17c8e 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -315,6 +315,16 @@ class TestLocalAccess(unittest.TestCase): sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42) + # NOTE: webgpu specific, since only webgpu performs bitpacking for uchar + @unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type") + def test_local_packed(self): + uops = [] + smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16)) + st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42))) + barr = uop(uops, Ops.BARRIER, dtypes.void, (st,)) + sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr)) + self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42) + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 9433849cf0..578a7422d2 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage): (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast({x.arg})" \ if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"), (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"), - 
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{type_map[x.dtype.base]}, {x.arg[1]}>;"),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base, True)}, {x.arg[1]}>;"),
   (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
   (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
@@ -78,7 +78,7 @@ class WGSLRenderer(CStyleLanguage):
     lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
   (UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
     # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
-    f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
+    f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
     else f"{ctx[b]} = {ctx[v]};"),
   # fix nan check: 'a != a -> is_nan()'
   (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
@@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage):
   def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
-  def render_buf(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
+  def render_buf_dt(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
     if not local_size: local_size = [1]
@@ -99,6 +99,6 @@ class WGSLRenderer(CStyleLanguage):
     prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
     prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
       f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
-      f"{name}:{f'array<{self.render_buf(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
+      f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
     return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

From dfae03858000f1ea1db7dbdb91831ab2b16cab86 Mon Sep 17 00:00:00 2001
From: Ahmed Harmouche
Date: Mon, 2 Dec 2024 10:27:59 +0100
Subject: [PATCH 14/23] Simplify render_buf_dt

---
 tinygrad/renderer/wgsl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index 578a7422d2..1eab1159cb 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage):
   (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
     if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
   (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast<i32>({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"),
-  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base, True)}, {x.arg[1]}>;"),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base)}, {x.arg[1]}>;"),
   (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
   (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
@@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage):
   def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
-  def render_buf_dt(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
+  def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
     if not local_size: local_size = [1]

From e2916ff210e294d208a0025567648996d78eda62 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 2 Dec 2024 05:25:13 -0500
Subject: [PATCH 15/23] image dtype fixup refactor for delete_lazy [pr] (#7989)

---
 tinygrad/engine/schedule.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 397b07c801..de262823e9 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -54,28 +54,28 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
   if buf is not buf.base:
     cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
     return ret
-  # make things that can't be images not images
-  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
-                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
-    if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
-    # hack the underlying buffer too
-    buf.dtype = buf.buffer.dtype = buf.dtype.base
-    assert not buf.is_realized, "can't fixup allocated buffer"
-    buf.buffer.options = None
   assert buf.op is not None, f"base must be base itself {buf}"
-  dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
+  # make things that can't be images not images
+  dtype = buf.dtype
+  if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+    assert buf.realized is None, "can't fixup allocated buffer"
+    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
+    dtype = buf.dtype.base
+    # hack the underlying buffer too
+    buf.buffer.dtype = buf.dtype = dtype
+    buf.buffer.options = None
   if buf.is_realized:
-    ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype)
+    ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
     buffers[ubuf] = buf.buffer
     op = None
   elif buf.op is Ops.ASSIGN:
    target, new_val = [to_uop(x, ctx, buffers, 
cache) for x in buf.srcs] ctx.assigns.add(ubuf:=target.buf_uop) - op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg) + op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg) else: - ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype) + ubuf = UOp.new_buffer(buf.device, buf.size, dtype) buffers[ubuf] = buf.buffer - op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg) + op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg) cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st) if op is not None: ctx.lazybufs[ubuf] = buf From 8909dbd82c811628e96e390ad38658148338cd96 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Mon, 2 Dec 2024 11:31:14 +0100 Subject: [PATCH 16/23] Remove wgpu specific checks from stable diffusion example (#7991) --- examples/stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index ef54f8a888..be4305771d 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -189,7 +189,7 @@ class StableDiffusion: # make image correct size and scale x = (x + 1.0) / 2.0 x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255 - return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x + return x.cast(dtypes.uint8) def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance): e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance) @@ -280,7 +280,7 @@ if __name__ == "__main__": print(x.shape) # save image - im = Image.fromarray(x.numpy().astype(np.uint8, copy=False)) + im = Image.fromarray(x.numpy()) print(f"saving {args.out}") im.save(args.out) # Open image. From 0c7477b1085bf094413d050f05b7c88d3e3bddd4 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:05:16 +0800 Subject: [PATCH 17/23] no bool in range [pr] (#7988) * no bool in range [pr] * fix llvm * add arg to range spec * fix broken test * forgot this one * hotfix: test_tiny jit is a real test --- test/test_linearizer.py | 44 +++++++++++------------------------ test/test_tiny.py | 15 ++++++++---- test/test_uop_graph.py | 4 ++-- tinygrad/codegen/linearize.py | 2 +- tinygrad/codegen/lowerer.py | 6 ++--- tinygrad/ops.py | 2 +- tinygrad/renderer/llvmir.py | 12 +++++----- 7 files changed, 38 insertions(+), 47 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 10b1a6e294..076555291a 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -100,6 +100,14 @@ class TestLinearizer(unittest.TestCase): assert len(mutable_bufs) == len(stores) == 2 assert [u.arg for u in mutable_bufs] == [0, 1] + def _test_no_nested_ranges(self, lins, skip=None): + for l in lins: + range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_ACC]) + ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)] + for i,u in enumerate(ranges): + if skip and i in skip: continue + assert ranges[i-1] != u, f"multireduce nested the ranges! 
{ranges[i-1], {u}}" + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @@ -130,11 +138,7 @@ class TestLinearizer(unittest.TestCase): ] wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1) lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i == 0: continue - assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -194,11 +198,7 @@ class TestLinearizer(unittest.TestCase): ] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i == 0: continue - assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0]) def test_triple_multireduce(self): Tensor.manual_seed(0) @@ -218,11 +218,7 @@ class TestLinearizer(unittest.TestCase): sink = UOp(Ops.SINK, src=(store,)) wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5) lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output]) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i == 0: continue - assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -270,11 +266,7 @@ class TestLinearizer(unittest.TestCase): Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], ] lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i < 2: continue - assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! 
{ranges[i-2], ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0, 1]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -301,11 +293,7 @@ class TestLinearizer(unittest.TestCase): ] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i == 0: continue - assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -339,11 +327,7 @@ class TestLinearizer(unittest.TestCase): ] wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1) lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts) - for l in lins: - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])] - for i,u in enumerate(ranges): - if i == 0: continue - assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" + self._test_no_nested_ranges(lins, [0]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_multiout_multireduce(self): diff --git a/test/test_tiny.py b/test/test_tiny.py index 04544d825f..394aea5955 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -1,5 +1,5 @@ # basic self-contained tests of the external functionality of tinygrad -import unittest +import unittest, random from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device from tinygrad.helpers import IMAGE @@ -41,14 +41,21 @@ class TestTiny(unittest.TestCase): def test_jit(self): cnt = 0 + def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)] + @TinyJit - def fxn(a,b): + def fxn(a,b) -> Tensor: nonlocal cnt cnt += 1 return a+b + for _ in range(3): - fa,fb = Tensor([1.,2,3]), Tensor([4.,5,6]) - fxn(fa, fb) + la,lb = new_rand_list(), new_rand_list() + fa,fb = Tensor(la), Tensor(lb) + ret = fxn(fa, fb) + # math is correct + self.assertListEqual(ret.tolist(), [a+b for a,b in zip(la, lb)]) + # function is only called twice self.assertEqual(cnt, 2) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index aa0e9aa26e..5311d0489f 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -433,8 +433,8 @@ class TestUOpGraph(unittest.TestCase): c0 = UOp.const(dtypes.int, 0) c2 = UOp.const(dtypes.int, 2) cf = UOp.const(dtypes.float, 0.0) - r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False)) - r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False)) + r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 0) + r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 1) alu = UOp(Ops.MUL, dtypes.int, (r2, r1)) store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index 98b446596d..c571e7f476 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -11,7 +11,7 @@ DONT_PLACE_IN_BLOCK = 
{Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops. def disp(y:UOp) -> str: if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0]) if y.op is Ops.IF: return f'IF{id(y)}' - if y.op is Ops.RANGE: return str(y.arg[0]) + if y.op is Ops.RANGE: return str(y.arg) return "" @dataclass(frozen=True) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 981aad7eeb..c71317cca7 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -71,10 +71,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext: get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max) else: # all loops are RANGES - idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])] + idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])] # reduce loops - idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True)) + idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)] # upcast loops @@ -85,7 +85,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext: # late indexes (group for reduce) ridxs = idxs[:] for a in range(first_reduce, first_reduce+group_for_reduces): - ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True)) + ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a) return IndexContext(idxs, ridxs) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 87e25f4b2a..16ae697f72 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -807,7 +807,7 @@ spec = PatternMatcher([ lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype), (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), - (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype), + (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)), (UPat(Ops.SPECIAL, src=()), lambda: True), # TODO: confirm the args of both of these are shapetrackers diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index f9f7d59a5f..c90e76da8a 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -60,13 +60,13 @@ llvm_rewrite = PatternMatcher([ # range (UPat(Ops.RANGE, name="x"), lambda ctx,x: - f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n" - f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n" - f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"), + f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n" + f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n" + f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"), (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x: - f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n" + f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n" f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n" - f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label 
%loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"), + f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"), # if (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"), @@ -131,7 +131,7 @@ class LLVMRenderer(Renderer): for x in acc_to_assign: if u in x.src: # if this range is relevent for this acc vc += 1 - kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]") + kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]") r[x] = f"%acc{vc}" # output the function From bb606e5bcf27d962884280298769d0be1281f7bc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 2 Dec 2024 06:38:31 -0500 Subject: [PATCH 18/23] process replayable ops.py changes from delete_lazy [pr] (#7994) * process replayable ops.py changes from delete_lazy [pr] * hotfix: seed tiny_jit --- test/test_tiny.py | 1 + tinygrad/ops.py | 15 +++++---------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/test/test_tiny.py b/test/test_tiny.py index 394aea5955..3bd2fda82a 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -41,6 +41,7 @@ class TestTiny(unittest.TestCase): def test_jit(self): cnt = 0 + random.seed(0) def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)] @TinyJit diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 16ae697f72..07b30a7967 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -277,7 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape) @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: - return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) + return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) @property def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape @property @@ -338,8 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype - if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None: - out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool + if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(arg, out_dtype, (self,)+src) @staticmethod def const(dtype:DType, b:ConstLike): @@ -384,13 +383,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): buffer_num = itertools.count(0) @staticmethod - def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype))) + def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype))) @functools.cached_property - def device(self) -> str: - match self.op: - case Ops.COPY: return self.arg - case Ops.BUFFER: return self.arg[1][0] - case _: return self.src[0].device + def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device @property def buf_uop(self) -> 
UOp: if self.op is Ops.BUFFER: return self @@ -627,7 +622,7 @@ class UPat(MathTrait): def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) def alu(self, op:Ops, *src:UPat): asrc = (self,)+src - return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc) + return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc) def printable(self:UPat) -> str: try: return lines(self.location[0])[self.location[1]-1].strip() From 077e7e8ed2b2f0c211ee86f877b27e6147893f61 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Mon, 2 Dec 2024 07:54:50 -0500 Subject: [PATCH 19/23] fix: private segment sgpr on gfx103x (#7987) Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- tinygrad/runtime/ops_amd.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 043818dd6f..62c53c506b 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -82,7 +82,12 @@ class AMDComputeQueue(HWQueue): def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]): self.acquire_mem(gli=0, gl2=0) - user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else [] + if prg.enable_private_segment_sgpr: + scratch_hilo = data64_le(prg.dev.scratch.va_addr) + # sgpr word1 bit31 enables swizzle + # sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23 + user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else [] + else: user_regs = [] if prg.enable_dispatch_ptr: dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size) @@ -370,12 +375,16 @@ class AMDDevice(HCQCompiled): max_cu_id = self.properties['simd_count'] // self.properties['simd_per_cu'] - 1 max_wave_id = self.properties['max_waves_per_simd'] * self.properties['simd_per_cu'] - 1 self.max_private_segment_size = 4096 - wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256 + # =gfx11 requires 256 + wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024) self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) self.has_scratch_base_registers = self.target >= 110000 engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine'] - self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines)) + waves = wave_scratch_len // (256 if self.target >= 110000 else 1024) + # >=gfx11 wavesize is per SE + wavesize = self.scratch_len // ((wave_scratch_len * engines) if self.target >= 110000 else wave_scratch_len) + self.tmpring_size = waves << 12 | wavesize # https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391 sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000 From 146e1caea314f4db24f7d2541eb3e70fdeb674b5 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Mon, 2 Dec 2024 15:48:44 +0100 Subject: [PATCH 20/23] Downgrade wgpu to prevent sd segfault (#7969) --- setup.py | 2 +- 
tinygrad/runtime/ops_webgpu.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index a9ec0db3ff..d7e2ac6b16 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ setup(name='tinygrad', "bottle", "ggml-python" ], - 'webgpu': ["wgpu>=v0.19.0"], + 'webgpu': ["wgpu==v0.18.1"], 'docs': [ "mkdocs", "mkdocs-material", diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 0e61c2553a..5a2e33c014 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -58,8 +58,8 @@ class WebGpuAllocator(Allocator): class WebGpuDevice(Compiled): def __init__(self, device:str): - adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") + adapter = wgpu.gpu.request_adapter(power_preference="high-performance") timestamp_supported = wgpu.FeatureName.timestamp_query in adapter.features - wgpu_device = adapter.request_device_sync(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else []) + wgpu_device = adapter.request_device(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else []) super().__init__(device, WebGpuAllocator(wgpu_device), WGSLRenderer(), Compiler(), functools.partial(WebGPUProgram, (wgpu_device, timestamp_supported))) From 0a2e10be1d944a33c8bfa2e46d146e90c0116c4f Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:04:01 +0800 Subject: [PATCH 21/23] add SELU to Tensor (#7993) * add selu * more clean ups --- docs/tensor/elementwise.md | 1 + extra/onnx_ops.py | 3 +-- test/test_ops.py | 3 +++ tinygrad/tensor.py | 13 +++++++++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/docs/tensor/elementwise.md b/docs/tensor/elementwise.md index fc5bd70d4e..3d3858ad79 100644 --- a/docs/tensor/elementwise.md +++ b/docs/tensor/elementwise.md @@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. 
They don't change the shape of t ::: tinygrad.Tensor.hardsigmoid ::: tinygrad.Tensor.elu ::: tinygrad.Tensor.celu +::: tinygrad.Tensor.selu ::: tinygrad.Tensor.swish ::: tinygrad.Tensor.silu ::: tinygrad.Tensor.relu6 diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 10d58b42f4..ae5afedca2 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -8,7 +8,7 @@ import numpy as np tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", - "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"} + "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"} # **************** Free Ops **************** @@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1) def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) -def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu()) def PRelu(X:Tensor, slope:Tensor): slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE return (X > 0).where(X, X * slope) diff --git a/test/test_ops.py b/test/test_ops.py index 2dd89ca7e3..a376eea19a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -697,6 +697,9 @@ class TestOps(unittest.TestCase): for val in range(1, 5): helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) + def test_selu(self): + helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu) + helper_test_op([()], torch.nn.functional.selu, Tensor.selu) def test_abs(self): helper_test_op([(45,65)], torch.abs, Tensor.abs) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 30a1f5f60c..24f076cc41 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2688,6 +2688,19 @@ class Tensor(SimpleMathTrait): """ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) + def selu(self, alpha=1.67326, gamma=1.0507): + """ + Applies the Scaled Exponential Linear Unit (SELU) function element-wise. 
+ + - Described: https://paperswithcode.com/method/selu + - Paper: https://arxiv.org/abs/1706.02515v5 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy()) + ``` + """ + return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1)) + def swish(self): """ See `.silu()` From b91fa243876e9e15a3e13ea599762e7f7385c89f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Dec 2024 15:34:27 -0500 Subject: [PATCH 22/23] script to run regressed sd conv on metal (#7995) * script to run regressed sd conv on metal this and other similar `conv2d + add` kernels contributed to most of the speed regression * # ruff: noqa: E501 --- test/external/external_debug_metal_sd_conv.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 test/external/external_debug_metal_sd_conv.py diff --git a/test/external/external_debug_metal_sd_conv.py b/test/external/external_debug_metal_sd_conv.py new file mode 100644 index 0000000000..8763e0cbc9 --- /dev/null +++ b/test/external/external_debug_metal_sd_conv.py @@ -0,0 +1,46 @@ +# ruff: noqa: E501 +from tinygrad.codegen.kernel import Kernel, Opt, OptOps +from tinygrad.dtype import dtypes +from tinygrad.engine.realize import CompiledRunner +from tinygrad.engine.search import bufs_from_lin +from tinygrad.helpers import Timing +from tinygrad.ops import UOp, Ops +from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.view import View + +ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(81920, 0, 64, 8, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ADD, dtypes.half, arg=None, src=( + UOp(Ops.ADD, dtypes.half, arg=None, src=( + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 2560, 4, 10, 4, 10), strides=(0, 163840, 0, 64, 0, 8, 0, 1), offset=-9, mask=((0, 1), (0, 2), (0, 1), (0, 2560), (0, 4), (1, 9), (0, 4), (1, 9)), contiguous=False), View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(4096000, 0, 0, 40, 1, 1600, 440, 11), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(0, 0, 23040, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), + x17:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()), + x17,)),)),)),)) +opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), 
Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)] + +k = Kernel(ast) +for opt in opts: k.apply_opt(opt) +bufs = bufs_from_lin(k) + +prg = CompiledRunner(k.to_program()) + +with Timing("run "): + prg(bufs, var_vals={}, wait=True) + +# on M1 Max +# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229 +# 15ms after block cbcc1c20eb09a1342f6581cfbb99632bade982a8 \ No newline at end of file From c7bc75e6340f1d74aee17244e90b43daad2bd70f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Dec 2024 17:19:27 -0500 Subject: [PATCH 23/23] alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) (#7900) * alu(c?t0:f0, c?t1:f1) -> c?alu(t0,t1):alu(f0,f1) only do if at least one branch is const, so total alu won't increase * tests and interesting TODO cases --- test/unit/test_uop_symbolic.py | 24 ++++++++++++++++++++++++ tinygrad/ops.py | 3 +++ 2 files changed, 27 insertions(+) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 255c94fbcf..e720712042 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)") self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)") + def test_where_combine(self): + cond = Variable("x", 0, 3).lt(2) + a = Variable("a", 0, 3) + b = Variable("b", 0, 3) + aa = cond.where(a, a.ufix(0)) + bb = cond.where(b, b.ufix(1)) + self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)") + self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)") + self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)") + self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)") + + # not combining because it increased total ALU + c = Variable("c", 0, 3) + cc = cond.where(c, c+1) + self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))") + + # not combining # TODO: can combine if it can further simplify? + ab = cond.where(a, b) + ba = cond.where(b, a) + self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))") + + # not combining # TODO: can combine if one is identity element const + self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))") + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): MIN, MAX = 0, 10 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 07b30a7967..2bc2999733 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1132,6 +1132,9 @@ symbolic = symbolic_simple+PatternMatcher([ # a conditional with the same results either way is a noop, also fold const conditionals (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), + # alu of two where with same conds can combine, only do if true branch or false branch is const + (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \ + lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), # ALU min==max -> CONST (slow!) (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding
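
A quick sketch of what the new where-combine rule does in practice (not part of the patch; it assumes
UOp.variable, lt, where, ufix, simplify and render behave as exercised in test_where_combine above):

  from tinygrad.ops import UOp
  x, a, b = UOp.variable("x", 0, 3), UOp.variable("a", 0, 3), UOp.variable("b", 0, 3)
  cond = x.lt(2)                                               # shared condition
  expr = cond.where(a, a.ufix(0)) + cond.where(b, b.ufix(1))   # each false branch is a const (0 and 1)
  print(expr.simplify().render())                              # expected: "((a+b) if (x<2) else 1)"

The ADD of the two wheres collapses into a single where, so the total ALU count does not grow, which is
the condition the commit message calls out.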