mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tuplize from linearizer behind flag (#13136)
* remove tuplize from linearizer
* optional tuplize
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import heapq
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
@@ -9,7 +10,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
||||
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
out_degree:dict[UOp, int] = {}
|
||||
priorities:dict[UOp, tuple[int, int]] = {}
|
||||
priorities:dict[UOp, tuple[int, int, Any]] = {}
|
||||
|
||||
# get consumers and assign priorities
|
||||
# NOTE: this requires the lst be locally toposorted
|
||||
@@ -22,19 +23,23 @@ def linearize(sink:UOp) -> list[UOp]:
|
||||
run_count = prod([int(r.vmax)+1 for r in u.ranges])
|
||||
|
||||
# simple priority override. this is all bottom up now, smaller numbers will be closer to the top
|
||||
extra = None
|
||||
match u.op:
|
||||
# the order and placement of these defines is important
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20
|
||||
case Ops.DEFINE_GLOBAL: priority, extra = -20, u.arg
|
||||
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
|
||||
case Ops.DEFINE_LOCAL: priority = -18
|
||||
case Ops.DEFINE_REG: priority = -17
|
||||
case Ops.CONST: priority = -10 # early consts
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.STORE: priority = 1 # place stores late
|
||||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
case Ops.END: priority = -5 # placing END is bad
|
||||
case _: priority = 0 # everything else has priority 0
|
||||
priorities[u] = (run_count, priority)
|
||||
priorities[u] = (run_count, priority, extra)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+x.tuplize))}
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
|
||||
|
||||
# then force them to be toposorted in as close to the ideal order as possible
|
||||
heap = [(-nkey[sink], sink)]
|
||||
|
||||
@@ -179,6 +179,8 @@ SPEC = ContextVar("SPEC", 1)
|
||||
IGNORE_OOB = ContextVar("IGNORE_OOB", 1)
|
||||
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
|
||||
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
|
||||
# when set to 1, the linearizer includes tuplize in its sort key
|
||||
TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
||||
Reference in New Issue
Block a user