diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 2617ef5044..7563409a94 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -382,7 +382,7 @@ jobs:
         cache-name: cache-gpuocelot-build
       with:
         path: ${{ github.workspace }}/gpuocelot/ocelot
-        key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-5
+        key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-6
     - name: Clone/compile gpuocelot
       if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv') && steps.cache-build.outputs.cache-hit != 'true'
       run: |
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 95b0c517c1..6b3209df16 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -50,9 +50,9 @@ class Linearizer(Kernel):
     return self.uops.add(UOps.ALU, dtype, (a, render_b), op)
 
   # NOTE: the consts have to be cached for deduping of downstream uops to work
-  def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before=None) -> UOp:
-    if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0], insert_before=insert_before)
-    else: return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
+  def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
+    if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0])
+    else: return self.uops.add(UOps.CONST, dtype, tuple(), b)
 
   def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
 
@@ -451,8 +451,9 @@ class Linearizer(Kernel):
   def to_program(self) -> Program:
     self.linearize()
     info = get_lazyop_info(self.ast[0])
+    src = self.opts.render(to_function_name(self.name), self.uops)
     ops, mem = self.uops.flops_mem()
     run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    return Program(self.name, self.opts.render(to_function_name(self.name), self.uops), self.opts.device,
-                   self.global_size, self.local_size, self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
+    return Program(self.name, src, self.opts.device, self.global_size, self.local_size,
+                   self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 2bf427c0e4..11ca5495aa 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Optional, Tuple, Any, Dict, List, DefaultDict, Set
+from typing import Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
 import functools, itertools, heapq
 from collections import defaultdict
 from enum import Enum, auto
@@ -74,7 +74,6 @@ def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
     if k == "__name__":
       if v in store and store[v] != uop: return False
       store[v] = uop
-    elif k[:2] == "__": continue
     elif k == "vin":
       # only one if it's a tuple
       # try all permutations if it's a list
@@ -88,17 +87,18 @@
         return False
     elif k in {"dtype", "uop"}:
       if uop.__getattribute__(k) not in (v if isinstance(v, set) else set([v])): return False
+    elif k[:2] == "__": continue
     else:
       if uop.__getattribute__(k) != v: return False
   return True
 
 class PatternMatcher:
-  def __init__(self, patterns:List[Tuple[Dict[str, Any], Any]]):
+  def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
     self.patterns = patterns
-    self.pdict = defaultdict(list)
+    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
     # uop is required, arg is optional
     for p,fxn in self.patterns:
-      uops = p.get("uop")
+      uops = p["uop"]
       if isinstance(uops, set):
         for uop in uops: self.pdict[(uop, p.get("arg", None))].append((p, fxn))
       else:
@@ -110,14 +110,6 @@ class PatternMatcher:
         if _match(uop, p, store): return fxn(**store)
     return None
 
-  def recursive_rewrite(self, uop:UOp) -> UOp:
-    run_cnt = 0
-    while (rewritten := self.rewrite(uop)):
-      assert run_cnt < 100, f"recursive_rewrite looped {uop} <--> {rewritten}"
-      uop = rewritten
-      run_cnt += 1
-    return uop
-
 def sum_collapse(phi_input, loop, val1, val2):
   for v1,v2 in [(val1, val2), (val2, val1)]:
     if loop not in v1.parents:
@@ -259,19 +251,7 @@ class UOpGraph:
     for i,u in enumerate(self):
       print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} "
             f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
-  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
-    # NOTE: relinearizering should be okay
-    #assert self._uops is None, "already linearized"
-    pm = PatternMatcher(constant_folder.patterns+extra_pm.patterns) if extra_pm is not None else constant_folder
-
-    # get sink
-    _sinks: List[UOp] = []
-    for u in self.nodes.values():
-      if u.uop is UOps.STORE: _sinks.append(u)
-      if u.uop is UOps.SINK: _sinks.extend(u.vin)
-    sink = UOp(UOps.SINK, None, tuple(_sinks))
-    del _sinks
-
+  def graph_rewrite(self, sink, pm):
     # recursive rewrite
     changed = getenv("UOPS_REWRITE", 1)
     run_cnt = 0
@@ -280,8 +260,14 @@ class UOpGraph:
       @functools.lru_cache
       def rewrite(u:UOp) -> UOp:
         nonlocal changed
-        up = pm.recursive_rewrite(u)
-        if up != u: changed += 1
+        recurse_cnt = 0
+        up = u
+        # locally recursively rewrite
+        while (rewritten := pm.rewrite(up)):
+          assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
+          up = rewritten
+          recurse_cnt += 1
+        changed += recurse_cnt
         up.vin = tuple(rewrite(x) for x in up.vin)
         if hasattr(up, "parents"): del up.parents
       # replace with cached nodes
@@ -291,27 +277,41 @@ class UOpGraph:
       sink = rewrite(sink)
      run_cnt += 1
      assert run_cnt < 100, "exceeded 100 rewrite loops!"
+    return sink
+
+  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
+    # NOTE: relinearizering should be okay
+    #assert self._uops is None, "already linearized"
+
+    # get sink
+    _sinks: List[UOp] = []
+    for u in self.nodes.values():
+      if u.uop is UOps.STORE: _sinks.append(u)
+      if u.uop is UOps.SINK: _sinks.extend(u.vin)
+    sink = UOp(UOps.SINK, None, tuple(_sinks))
+    del _sinks
+
+    sink = self.graph_rewrite(sink, constant_folder)
+    if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
 
     # filter nodes that don't link to a sink
-    nodes: Dict[UOp, None] = {}
-    def add_parents(u:UOp):
-      if u in nodes: return
-      nodes[u] = None
-      for x in u.vin: add_parents(x)
-    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
-    add_parents(sink)
-
     # BFS toposort
     graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
     in_degree: DefaultDict[UOp, int] = defaultdict(int)
     loops = []
     ifs = []
-    for u in nodes:
+    nodes: Dict[UOp, None] = {}
+    def add_parents(u:UOp):
+      if u in nodes: return
+      nodes[u] = None
       for x in u.vin:
+        add_parents(x)
         in_degree[u] += 1
         graph[x].append(u)
       if u.uop is UOps.LOOP: loops.append(u)
       if u.uop is UOps.IF: ifs.append(u)
+    sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
+    add_parents(sink)
 
     @functools.lru_cache(None)
     def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
@@ -355,8 +355,7 @@ class UOpGraph:
 
     if type_verify: self.type_verify()
 
-  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None,
-          cachable=True, insert_before=None, simplify=True) -> UOp:
+  def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
     if uop is UOps.CONST:
       assert dtype is not None
       arg = dtypes.as_const(arg, dtype) # TODO: this doesn't belong here
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 4762dc421d..938a8fdf39 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -259,6 +259,11 @@ ptx_matcher = PatternMatcher([
   ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
    lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
   # ptr_ar (load/store)
+  ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+    {"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
+   lambda root, alu, const: UOp(root.uop, root.dtype,
+    (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
+     UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
   ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
    {"__name__": "const", "uop":UOps.CONST})},
   lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
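
Reviewer note on the `PatternMatcher` change: switching `p.get("uop")` to `p["uop"]` makes the `"uop"` key mandatory, so every pattern can be bucketed into `pdict` by `(uop, arg)` and a rewrite only scans the candidates registered for that op rather than the whole pattern list. A minimal standalone sketch of that dispatch idea (hypothetical names and rules, not tinygrad's API):

```python
from collections import defaultdict

# Patterns bucketed by (op, arg); "uop" is required, "arg" is optional.
patterns = [
    ({"uop": "ALU", "arg": "ADD"}, lambda: "fold x+0"),
    ({"uop": "CONST"},             lambda: "dedup const"),
]

pdict = defaultdict(list)
for p, fxn in patterns:
    pdict[(p["uop"], p.get("arg"))].append((p, fxn))  # KeyError if "uop" is missing, by design

def candidates(uop, arg):
    # exact (uop, arg) bucket first, then the wildcard-arg bucket
    return pdict[(uop, arg)] + pdict[(uop, None)]

print([fxn() for _, fxn in candidates("ALU", "ADD")])  # ['fold x+0']
print([fxn() for _, fxn in candidates("CONST", 3)])    # ['dedup const']
```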
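Removing `PatternMatcher.recursive_rewrite` inlines the per-node fixed point into `UOpGraph.graph_rewrite`: each node is rewritten locally until no pattern fires (`recurse_cnt < 100` guards against an A <--> B ping-pong), and the outer loop repeats whole-graph passes until a pass changes nothing (`run_cnt < 100`). A runnable toy version of this two-level fixed point, using hypothetical string-rewrite rules in place of UOp patterns:

```python
# Toy two-level fixed-point rewrite (hypothetical rules, not tinygrad's UOps).
rules = {"a+0": "a", "0+a": "a", "a*1": "a"}

def rewrite_once(node: str):
    return rules.get(node)  # None when no rule applies, like PatternMatcher.rewrite

def rewrite_node(node: str) -> str:
    # level 1: rewrite one node until no rule fires
    recurse_cnt = 0
    while (rewritten := rewrite_once(node)) is not None:
        assert recurse_cnt < 100, f"rewrite looped {node} <--> {rewritten}"
        node, recurse_cnt = rewritten, recurse_cnt + 1
    return node

def graph_rewrite(nodes: list) -> list:
    # level 2: re-run whole-graph passes until a pass makes no changes
    run_cnt = 0
    while True:
        new_nodes = [rewrite_node(n) for n in nodes]
        run_cnt += 1
        assert run_cnt < 100, "exceeded 100 rewrite loops!"
        if new_nodes == nodes: return new_nodes  # fixed point reached
        nodes = new_nodes

print(graph_rewrite(["a+0", "0+a", "b"]))  # ['a', 'a', 'b']
```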
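The new `ptx_matcher` rule appears to target LOAD/STORE indices of the form `alu + const` on a DEFINE_LOCAL/DEFINE_GLOBAL buffer: the variable part is widened to int64, scaled by `itemsize`, and folded into the base address, while the constant part becomes a separate byte offset (`const * itemsize`), which PTX addressing can encode as an immediate in the `[reg+imm]` operand. A quick arithmetic check of that split (illustrative numbers; `itemsize=4` as for float32):

```python
# Sanity check of the address split performed by the new ptx_matcher rule
# (illustrative values only; base, alu, and const are made up for this check).
itemsize, base = 4, 0x1000
alu, const = 7, 3                       # index = alu + const
flat = base + (alu + const) * itemsize  # what the un-split access computes
reg  = base + alu * itemsize            # 64-bit register part of the address
imm  = const * itemsize                 # constant byte offset, a PTX immediate
assert reg + imm == flat
print(hex(reg), imm)                    # e.g. rendered as ld.global.f32 %f0, [%rd1+12]
```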