var_vals uses str for var (#12011)

* var_vals is str,int

* remove imports

* remove print

* fix test

* change var_vals in hcq

* update test_hcq

* fix multitensor _device_num var

* fix syminfer test

* shorten line

* p.vars stays list[Variable]

* shorten line

* vars is back to tuple[Variable, ...]

* change var_vals in extra

* change var_vals from shapetracker

* var_vals is str:int

* fix signature
Sieds Lykles authored 2025-09-06 04:16:12 +02:00, committed by GitHub
parent 8658a97197
commit c6c16b2946
24 changed files with 90 additions and 91 deletions
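The change across all 24 files is the same re-keying: every var_vals dict goes from Variable UOps to their string names (the .expr of the DEFINE_VAR). A minimal before/after sketch with made-up values, using the Variable and sym_infer helpers that appear throughout the diff; this snippet is illustrative, not taken from the commit:

from tinygrad.uop.ops import Variable, sym_infer

x = Variable("x", 1, 100)      # a DEFINE_VAR UOp; x.expr == "x"
expr = x * 4 + 2               # a symbolic expression built from it

old_var_vals = {x: 3}          # before: keyed by the Variable UOp itself
new_var_vals = {x.expr: 3}     # after: keyed by the variable's name, {"x": 3}

assert sym_infer(expr, new_var_vals) == 14   # sym_infer now takes dict[str, int]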

View File

@@ -10,13 +10,13 @@ from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype
class ClangGraph(GraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
args += sorted([f"int {v.expr}" for v in var_vals])
args += sorted([f"int {v}" for v in var_vals])
code = ["void batched("+','.join(args)+") {"]
for ji in jit_cache:
args = []
@@ -34,6 +34,6 @@ class ClangGraph(GraphRunner):
assert compiler is not None
self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False):
return cpu_time_execution(
lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)
lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait)

View File

@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
self.available_packet_slots -= 1
class HSAGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
@@ -53,7 +53,7 @@ class HSAGraph(MultiGraphRunner):
self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr])
# Build queues.
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
@@ -106,7 +106,7 @@ class HSAGraph(MultiGraphRunner):
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]:
# Wait and restore signals
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
@@ -123,7 +123,7 @@ class HSAGraph(MultiGraphRunner):
# Update var_vals
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr])
# Update launch dims
for j in self.jc_idx_with_updatable_launch_dims:

View File

@@ -88,7 +88,7 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
return ret
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
dev = Device[lin.opts.device]
root = MCTSNode(lin)

View File

@@ -115,7 +115,7 @@ def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_
assert dev.compiler is not None
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = get_program(lin.get_optimized_ast(), lin.opts)
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

View File

@@ -120,7 +120,7 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
if i > 1000:
print("WARNING: did not search all possible combinations")
break
var_vals = {k:v for k,v in zip(vs, ranges)}
var_vals = {k.expr:v for k,v in zip(vs, ranges)}
r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
if r1 != r2: return False

View File

@@ -52,12 +52,12 @@ class TestHCQ(unittest.TestCase):
with self.subTest(name=str(queue_type)):
q = queue_type().signal(virt_signal, virt_val)
var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -106,7 +106,7 @@ class TestHCQ(unittest.TestCase):
fake_signal.value = 0x30
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -131,7 +131,7 @@ class TestHCQ(unittest.TestCase):
.signal(TestHCQ.d0.timeline_signal, virt_val)
for _ in range(100):
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
@@ -146,7 +146,7 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
q.submit(TestHCQ.d0, {sint_global[0].expr: 1, sint_local[0].expr: 1})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -181,7 +181,7 @@ class TestHCQ(unittest.TestCase):
for z in range(1, 4):
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value, virt_local[0].expr: x, virt_local[1].expr: y, virt_local[2].expr: z})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -253,7 +253,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr})
q.submit(TestHCQ.d0, {virt_src_addr.expr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr.expr: TestHCQ.b.uop.buffer._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -276,7 +276,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, sz) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
q.submit(TestHCQ.d0, {virt_src_addr.expr: buf2._buf.va_addr, virt_dest_addr.expr: buf1._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -299,7 +299,7 @@ class TestHCQ(unittest.TestCase):
fake_signal.value = 0x30
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

View File

@@ -90,7 +90,7 @@ def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_de
def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]: # (error msg, run state)
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in lin.vars}
if var_vals is None: var_vals = {v.expr: v.min for v in lin.vars}
# TODO: images needs required_optimization
try:
@@ -129,7 +129,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No
if var_vals is None:
# TODO: handle symbolic max case
var_vals = {v: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()}
var_vals = {v.expr: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()}
if ground_truth is None and not has_bf16:
unoptimized = Kernel(lin.ast)

View File

@@ -476,7 +476,7 @@ class TestUOpMethod(unittest.TestCase):
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
_, var_vals = (uop_var+st_var).schedule_with_vars()
self.assertEqual(len(var_vals), 1)
self.assertEqual(list(var_vals)[0], a)
self.assertEqual(list(var_vals)[0], a.expr)
def test_const_factor(self):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 8),), 'gidx0')

View File

@@ -73,19 +73,19 @@ class TestSymbolicVarVals(unittest.TestCase):
def test_var_vals_shape(self):
x = Variable("x", 1, 100).bind(3)
assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3}
assert ShapeTracker.from_shape((x, 3)).var_vals == {"x": 3}
def test_var_vals_offset(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
self.assert_equal(st.views[-1].offset, x * 3)
assert st.var_vals == {Variable("x", 1, 100): 3}
assert st.var_vals == {"x": 3}
def test_var_vals_mask(self):
x = Variable("x", 1, 100).bind(3)
view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4)))
st = ShapeTracker(views=(view,))
assert st.var_vals == {Variable("x", 1, 100): 3}
assert st.var_vals == {"x": 3}
def test_var_vals_complex(self):
x = Variable("x", 1, 100).bind(3)
@@ -93,13 +93,13 @@ class TestSymbolicVarVals(unittest.TestCase):
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
self.assert_equal(st.views[-1].offset, y * z)
assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
assert st.var_vals == {"x": 3, "y": 4, "z": 5}
def test_shrink_reshape(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5)))
st = st.reshape((3*4*3,))
assert st.var_vals == {Variable("x", 1, 100): 3}
assert st.var_vals == {"x": 3}
class TestShapeTrackerUnbind(unittest.TestCase):
def test_view_unbind(self):

View File

@@ -804,7 +804,7 @@ class TestSymInfer(unittest.TestCase):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
var_vals = {a: 2, b: 3, c: 4}
var_vals = {a.expr: 2, b.expr: 3, c.expr: 4}
assert sym_infer(5, var_vals) == 5
assert sym_infer(4.2, var_vals) == 4.2
assert sym_infer(a, var_vals) == 2
@@ -817,7 +817,7 @@ class TestSymInfer(unittest.TestCase):
def test_sym_infer_cdiv_cmod(self):
a = Variable("a", -1000, 1)
b = Variable("b", -1000, 1)
var_vals = {a: 1, b: -1000}
var_vals = {a.expr: 1, b.expr: -1000}
assert sym_infer(a%b, var_vals) == 1
assert sym_infer(a//b, var_vals) == 0

View File

@@ -1,7 +1,7 @@
from typing import cast
import functools, math, time, multiprocessing, traceback, signal, atexit
from dataclasses import replace
from tinygrad.uop.ops import Variable, sym_infer, AxisType, pyrender
from tinygrad.uop.ops import sym_infer, AxisType, pyrender
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE
@@ -34,7 +34,7 @@ def get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:float|None=None,
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
@@ -141,7 +141,7 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr
try:
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:

View File

@@ -16,7 +16,7 @@ class GraphException(Exception): pass
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache: list[ExecItem] = []
@@ -73,7 +73,7 @@ def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer])
return input_replace
class GraphRunner(Runner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
@@ -82,7 +82,7 @@ class GraphRunner(Runner):
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
self.vars = sorted(var_vals.keys())
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
@@ -91,7 +91,7 @@ class GraphRunner(Runner):
for j,ji in enumerate(jit_cache):
estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner):
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
if global_dim_idx is not None or local_dim_idx is not None:
@@ -105,12 +105,12 @@ class GraphRunner(Runner):
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
def updated_vars(self, var_vals: dict[Variable, int]):
def updated_vars(self, var_vals: dict[str, int]):
vals = [var_vals[v] for v in self.vars]
for j, vidxs in self.var_vals_replace.items():
for i, v in vidxs: yield j, i, vals[v]
def updated_launch_dims(self, var_vals: dict[Variable, int]):
def updated_launch_dims(self, var_vals: dict[str, int]):
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
for j, (gl, lc) in self.launch_dims_replace.items():
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
@@ -192,7 +192,7 @@ class CapturedJit(Generic[ReturnType]):
self.__post_init__()
# jit exec
def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
# assign inputs
for idx, offset, device, size, dtype in self.extra_view_inputs:
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
@@ -225,7 +225,8 @@ def _prepare_jit_inputs(args, kwargs):
for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
var_vals = {k.expr:v for k,v in _var_vals.items()}
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
return input_buffers, var_vals, names, st_vars_dtype_device
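In GraphRunner above, self.vars is now a plain lexicographically sorted list of names, while p.vars on each kernel stays tuple[Variable, ...]; the bridge between the two is v.expr. A small sketch of the lookup those replaced lines perform, with hypothetical variable names ("start_pos", "seq_len") and values:

from tinygrad.uop.ops import Variable

p_vars = (Variable("start_pos", 1, 1024), Variable("seq_len", 1, 128))    # ji.prg.p.vars stays Variables
var_vals = {"seq_len": 32, "start_pos": 100}                              # graph-level dict is keyed by str

vars_sorted = sorted(var_vals.keys())                                     # ["seq_len", "start_pos"]
replace = [(i, vars_sorted.index(v.expr)) for i, v in enumerate(p_vars)]  # [(0, 1), (1, 0)]

vals = [var_vals[v] for v in vars_sorted]                                 # [32, 100]
resolved = {i: vals[vidx] for i, vidx in replace}                         # kernel arg 0 <- 100, arg 1 <- 32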

View File

@@ -3,7 +3,7 @@ import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
@@ -56,9 +56,9 @@ class Runner:
self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
@property
def dev(self): return Device[self.device]
def exec(self, rawbufs:list[Buffer], var_vals:dict[Variable, int]|None=None) -> float|None:
def exec(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None) -> float|None:
return self(rawbufs, {} if var_vals is None else var_vals)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
raise NotImplementedError("override this")
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
@@ -89,7 +89,7 @@ class CompiledRunner(Runner):
def __reduce__(self): return self.__class__, (self.p, self.lib)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
global_size, local_size = self.p.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
local_size = optimize_local_size(self._prg, global_size, rawbufs)
@@ -102,11 +102,11 @@ class CompiledRunner(Runner):
if local_size:
lra['local_size'] = tuple(local_size)
assert len(local_size) == 3, "local size must have len 3"
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
class BufferCopy(Runner):
@@ -124,7 +124,7 @@ class BufferCopy(Runner):
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:
dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
dest, src = rawbufs[0:2]
assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
st = time.perf_counter()
@@ -159,8 +159,8 @@ class ExecItem:
prg: Runner
bufs: list[Buffer|None]
metadata: tuple[Metadata, ...]|None = None
fixedvars: dict[Variable, int] = field(default_factory=dict)
def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
fixedvars: dict[str, int] = field(default_factory=dict)
def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", self.prg.display_name, {"metadata":self.metadata, "var_vals":var_vals}))
@@ -211,7 +211,7 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem,
capturing: list = [] # put classes with an add method in here
def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=None, do_update_stats=True):
def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
for si, ei in lower_schedule(schedule):
if len(capturing) and CAPTURING: capturing[0].add(ei)
if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
@@ -229,4 +229,3 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=
np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else:
ei.run(var_vals, do_update_stats=do_update_stats)

View File

@@ -1,7 +1,7 @@
from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Variable, Ops, buffers
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same
@@ -12,15 +12,15 @@ class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[Variable, int] = field(default_factory=dict)
fixedvars: dict[str, int] = field(default_factory=dict)
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
# construct the KERNEL children graph based on assigns
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: dict[UOp, int] = {}
var_vals: dict[Variable, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
k = u.src[1]
@@ -40,8 +40,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
pass # a BUFFER is already realized, nothing to do here
elif s.op is Ops.BIND:
var, val = s.unbind()
assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
var_vals[var] = val
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
else:
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
@@ -72,7 +72,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
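The BIND handling above is where the string keys enter the schedule: each bound Variable is unbound and its value stored under the variable's name (the _device_num case does the same for the per-device index). A short sketch of that unbind step, with a made-up value, assuming the public bind/unbind API shown in the tests earlier in the diff:

from tinygrad.uop.ops import Variable

x = Variable("x", 1, 100)
bound = x.bind(3)                  # an Ops.BIND UOp carrying (x, 3)
var, val = bound.unbind()          # -> (the Variable, 3)

var_vals: dict[str, int] = {}
assert var.expr not in var_vals or var_vals[var.expr] == val
var_vals[var.expr] = val           # {"x": 3}, mirroring the store in create_schedule_with_vars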

View File

@@ -101,7 +101,7 @@ class ProgramSpec:
def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None
def launch_dims(self, var_vals:dict[Variable, int]):
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
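launch_dims simply maps sym_infer over each (possibly symbolic) dimension with the str-keyed dict. A hedged example with an invented symbolic global size:

from tinygrad.uop.ops import Variable, sym_infer

i = Variable("i", 1, 10)
global_size = [i * 4, 16, 1]                                  # mixed symbolic/int dims
resolved = [sym_infer(sz, {"i": 3}) for sz in global_size]    # [12, 16, 1]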

View File

@@ -4,12 +4,11 @@ import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.uop.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
class CUDAGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
@@ -28,7 +27,7 @@ class CUDAGraph(MultiGraphRunner):
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
@@ -48,7 +47,7 @@ class CUDAGraph(MultiGraphRunner):
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)

View File

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
@@ -69,7 +69,7 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}
for j,ji in enumerate(jit_cache):
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
@@ -183,7 +183,7 @@ class HCQGraph(MultiGraphRunner):
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -195,12 +195,13 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
for (j,i),input_idx in self.input_replace.items():
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))

View File

@@ -5,7 +5,6 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
@@ -17,7 +16,7 @@ class MTLResourceUsage:
MTLResourceUsageWrite = 0b10
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
@@ -48,7 +47,8 @@ class MetalGraph(GraphRunner):
if b is not None and b not in input_rawbuffers:
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.vars):
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
@@ -61,7 +61,7 @@ class MetalGraph(GraphRunner):
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = to_struct(0, len(jit_cache))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()

View File

@@ -1,5 +1,4 @@
import time, itertools
from tinygrad.uop.ops import Variable
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
@@ -18,7 +17,7 @@ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
@@ -93,7 +92,7 @@ class RemoteGraph(MultiGraphRunner):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:

View File

@@ -100,7 +100,7 @@ class GraphComputeItem:
datahash: str
bufs: tuple[int, ...]
vars: tuple[Variable, ...]
fixedvars: dict[Variable, int]
fixedvars: dict[str, int]
ins: tuple[int, ...]
outs: tuple[int, ...]
global_size: tuple[sint, ...]|None
@@ -111,7 +111,7 @@ class GraphAlloc(RemoteRequest):
graph_num: int
jit_cache: tuple[GraphComputeItem|Transfer, ...]
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[Variable, int]
var_vals: dict[str, int]
@dataclass(frozen=True)
class GraphFree(RemoteRequest):
@@ -121,7 +121,7 @@ class GraphFree(RemoteRequest):
class GraphExec(RemoteRequest):
graph_num: int
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[Variable, int]
var_vals: dict[str, int]
wait: bool
# for safe deserialization

View File

@@ -6,7 +6,7 @@ except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
class MMIOInterface:
@@ -192,7 +192,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
def _apply_var_vals(self, var_vals:dict[Variable, int]):
def _apply_var_vals(self, var_vals:dict[str, int]):
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
for off, sym_idx in self.q_sints:
@@ -205,7 +205,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
def submit(self, dev:HCQDeviceType, var_vals:dict[str, int]|None=None):
"""
Submits the command queue to a specific device for execution.

View File

@@ -97,7 +97,7 @@ class ShapeTracker:
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
@property
def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
def var_vals(self) -> dict[str, int]: return merge_dicts([{(vu:=v.unbind())[0].expr:vu[1]} for v in self.vars()])
def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
from tinygrad.gradient import compute_gradient
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
@@ -241,7 +241,7 @@ class Tensor(MathTrait):
_apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
return self
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]:
"""
Creates the schedule needed to realize these Tensor(s), with Variables.

View File

@@ -37,7 +37,7 @@ def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x)
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
@@ -486,7 +486,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self):
def expr(self) -> str:
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
@@ -584,9 +584,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
def sym_infer(self, var_vals:dict[UOp, int]):
def sym_infer(self, var_vals:dict[str, int]):
fxn, varnames = self._sym_fxn
return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
return fxn(**{k:v for k,v in var_vals.items() if k in varnames})
def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
with Context(TRACK_MATCH_STATS=0):
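With the new signature, UOp.sym_infer filters the str-keyed dict against the expression's own variable names, so extra entries are simply ignored. A quick sketch consistent with the TestSymInfer cases earlier in the diff (the "unused" key is invented for illustration):

from tinygrad.uop.ops import Variable, sym_infer

a, b = Variable("a", 0, 10), Variable("b", 0, 10)
var_vals = {"a": 2, "b": 3, "unused": 7}      # "unused" is filtered out by the varname check

assert sym_infer(a + b * 2, var_vals) == 8
assert sym_infer(5, var_vals) == 5            # non-UOp values pass straight through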