mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use existing VIZ ContextVar instead of getenv (#14610)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
from dataclasses import replace
|
||||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, TracingKey, Context
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
@@ -25,7 +25,7 @@ from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_c
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(sink))
|
||||
if SPEC: type_verify(sink, kernel_spec)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import cast
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.device import Device
|
||||
|
||||
@@ -224,7 +224,7 @@ multi_pm = PatternMatcher([
|
||||
])+replace_allreduce
|
||||
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
|
||||
if getenv("VIZ"): graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
|
||||
if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
|
||||
return ret
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -561,7 +561,7 @@ replace_contiguous = PatternMatcher([
|
||||
])
|
||||
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
@@ -579,7 +579,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER, Ops.AFTER} and \
|
||||
x.tag is not None and len(x.tag)])
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
|
||||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
@@ -605,7 +605,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
sink_tags = [s.tag for s in tsink.src]
|
||||
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for tag, s in zip(sink_tags, tsink.src):
|
||||
|
||||
Reference in New Issue
Block a user