diff --git a/tinygrad/codegen/late/linearizer.py b/tinygrad/codegen/late/linearizer.py index e9478a7906..c44aa3f7c4 100644 --- a/tinygrad/codegen/late/linearizer.py +++ b/tinygrad/codegen/late/linearizer.py @@ -1,7 +1,8 @@ import heapq +from typing import Any from collections import defaultdict from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str -from tinygrad.helpers import prod, getenv +from tinygrad.helpers import prod, getenv, TUPLE_ORDER def linearize(sink:UOp) -> list[UOp]: # this is a toposort with priority @@ -9,7 +10,7 @@ def linearize(sink:UOp) -> list[UOp]: consumers: defaultdict[UOp, list[UOp]] = defaultdict(list) in_degree:dict[UOp, int] = {} out_degree:dict[UOp, int] = {} - priorities:dict[UOp, tuple[int, int]] = {} + priorities:dict[UOp, tuple[int, int, Any]] = {} # get consumers and assign priorities # NOTE: this requires the lst be locally toposorted @@ -22,19 +23,23 @@ def linearize(sink:UOp) -> list[UOp]: run_count = prod([int(r.vmax)+1 for r in u.ranges]) # simple priority override. this is all bottom up now, smaller numbers will be closer to the top + extra = None match u.op: # the order and placement of these defines is important - case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20 + case Ops.DEFINE_GLOBAL: priority, extra = -20, u.arg + case Ops.DEFINE_VAR: priority, extra = -19, u.arg + case Ops.DEFINE_LOCAL: priority = -18 + case Ops.DEFINE_REG: priority = -17 case Ops.CONST: priority = -10 # early consts case Ops.LOAD: priority = -1 # place loads early case Ops.STORE: priority = 1 # place stores late case Ops.RANGE: priority = 5 # placing RANGE is good case Ops.END: priority = -5 # placing END is bad case _: priority = 0 # everything else has priority 0 - priorities[u] = (run_count, priority) + priorities[u] = (run_count, priority, extra) # number the uops in "ideal" order - nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+x.tuplize))} + nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))} # then force then to be toposorted in as close to the ideal order as possible heap = [(-nkey[sink], sink)] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 9ff058ffdb..013d6d53e1 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -179,6 +179,8 @@ SPEC = ContextVar("SPEC", 1) IGNORE_OOB = ContextVar("IGNORE_OOB", 1) PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0) +# set to 1, this uses tuplize in the linearizer sort order +TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1) @dataclass(frozen=True) class Metadata: