var_vals uses str for var (#12011)

* var_vals is str,int
* remove imports
* remove print
* fix test
* change var_vals in hcq
* update test_hcq
* fix multitensor _device_num var
* fix syminfer test
* shorten line
* p.vars stays list[Variable]
* shorten line
* vars is back to tuple[Variable, ...]
* change var_vals in extra
* change var_vals from shapetracker
* var_vals is str:int
* fix signature
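In short: every `var_vals` dictionary is now keyed by the variable's name (`Variable.expr`, a `str`) rather than by the `Variable` UOp itself, so callers that used to pass `{var: val}` now pass `{var.expr: val}`. A minimal sketch of the new keying, using the same values as the updated `TestSymInfer` test further down (the variable names are only illustrative):

```python
from tinygrad.uop.ops import Variable, sym_infer

a = Variable("a", 0, 10)
b = Variable("b", 0, 10)

# before this PR: var_vals = {a: 2, b: 3}   (keyed by the Variable UOp)
# after this PR:  keyed by the variable name, i.e. a.expr == "a"
var_vals: dict[str, int] = {a.expr: 2, b.expr: 3}

assert sym_infer(a, var_vals) == 2      # sym_infer now looks values up by name
assert sym_infer(a + b, var_vals) == 5  # works for symbolic expressions over a and b
```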
@@ -10,13 +10,13 @@ from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype

class ClangGraph(GraphRunner):
-def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
-args += sorted([f"int {v.expr}" for v in var_vals])
+args += sorted([f"int {v}" for v in var_vals])
code = ["void batched("+','.join(args)+") {"]
for ji in jit_cache:
args = []
@@ -34,6 +34,6 @@ class ClangGraph(GraphRunner):
assert compiler is not None
self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers

-def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
+def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False):
return cpu_time_execution(
-lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)
+lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait)

@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
self.available_packet_slots -= 1

class HSAGraph(MultiGraphRunner):
-def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)

# Check all jit items are compatible.
@@ -53,7 +53,7 @@ class HSAGraph(MultiGraphRunner):
self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16)
for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
-for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr])

# Build queues.
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
@@ -106,7 +106,7 @@ class HSAGraph(MultiGraphRunner):
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)

-def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]:
# Wait and restore signals
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
@@ -123,7 +123,7 @@ class HSAGraph(MultiGraphRunner):
# Update var_vals
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
-self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr])

# Update launch dims
for j in self.jc_idx_with_updatable_launch_dims:

@@ -88,7 +88,7 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
return ret

rawbufs = _ensure_buffer_alloc(rawbufs)
-var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
dev = Device[lin.opts.device]
root = MCTSNode(lin)

@@ -115,7 +115,7 @@ def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_
assert dev.compiler is not None

rawbufs = _ensure_buffer_alloc(rawbufs)
-var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = get_program(lin.get_optimized_ast(), lin.opts)
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

@@ -120,7 +120,7 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
if i > 1000:
print("WARNING: did not search all possible combinations")
break
-var_vals = {k:v for k,v in zip(vs, ranges)}
+var_vals = {k.expr:v for k,v in zip(vs, ranges)}
r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
if r1 != r2: return False

@@ -52,12 +52,12 @@ class TestHCQ(unittest.TestCase):
with self.subTest(name=str(queue_type)):
q = queue_type().signal(virt_signal, virt_val)

-var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
+var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

-var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
+var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -106,7 +106,7 @@ class TestHCQ(unittest.TestCase):

fake_signal.value = 0x30

-q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
+q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

@@ -131,7 +131,7 @@ class TestHCQ(unittest.TestCase):
.signal(TestHCQ.d0.timeline_signal, virt_val)

for _ in range(100):
-q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
+q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1

val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
@@ -146,7 +146,7 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
+q.submit(TestHCQ.d0, {sint_global[0].expr: 1, sint_local[0].expr: 1})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

@@ -181,7 +181,7 @@ class TestHCQ(unittest.TestCase):
for z in range(1, 4):
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)

-q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
+q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value, virt_local[0].expr: x, virt_local[1].expr: y, virt_local[2].expr: z})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

@@ -253,7 +253,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr})
+q.submit(TestHCQ.d0, {virt_src_addr.expr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr.expr: TestHCQ.b.uop.buffer._buf.va_addr})

TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -276,7 +276,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, sz) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
+q.submit(TestHCQ.d0, {virt_src_addr.expr: buf2._buf.va_addr, virt_dest_addr.expr: buf1._buf.va_addr})

TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -299,7 +299,7 @@ class TestHCQ(unittest.TestCase):

fake_signal.value = 0x30

-q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
+q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

test/external/fuzz_linearizer.py (vendored, 4 changed lines)
@@ -90,7 +90,7 @@ def get_fuzz_rawbuf_like(old_rawbuf, zero=False, copy=False, size=None, force_de

def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> tuple[str, Any]: # (error msg, run state)
if rawbufs is None: rawbufs = bufs_from_lin(lin)
-if var_vals is None: var_vals = {v: v.min for v in lin.vars}
+if var_vals is None: var_vals = {v.expr: v.min for v in lin.vars}

# TODO: images needs required_optimization
try:
@@ -129,7 +129,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No

if var_vals is None:
# TODO: handle symbolic max case
-var_vals = {v: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()}
+var_vals = {v.expr: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()}

if ground_truth is None and not has_bf16:
unoptimized = Kernel(lin.ast)

@@ -476,7 +476,7 @@ class TestUOpMethod(unittest.TestCase):
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
_, var_vals = (uop_var+st_var).schedule_with_vars()
self.assertEqual(len(var_vals), 1)
-self.assertEqual(list(var_vals)[0], a)
+self.assertEqual(list(var_vals)[0], a.expr)

def test_const_factor(self):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 8),), 'gidx0')

@@ -73,19 +73,19 @@ class TestSymbolicVarVals(unittest.TestCase):

def test_var_vals_shape(self):
x = Variable("x", 1, 100).bind(3)
-assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3}
+assert ShapeTracker.from_shape((x, 3)).var_vals == {"x": 3}

def test_var_vals_offset(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3)))
self.assert_equal(st.views[-1].offset, x * 3)
-assert st.var_vals == {Variable("x", 1, 100): 3}
+assert st.var_vals == {"x": 3}

def test_var_vals_mask(self):
x = Variable("x", 1, 100).bind(3)
view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4)))
st = ShapeTracker(views=(view,))
-assert st.var_vals == {Variable("x", 1, 100): 3}
+assert st.var_vals == {"x": 3}

def test_var_vals_complex(self):
x = Variable("x", 1, 100).bind(3)
@@ -93,13 +93,13 @@ class TestSymbolicVarVals(unittest.TestCase):
z = Variable("z", 1, 100).bind(5)
st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3)))
self.assert_equal(st.views[-1].offset, y * z)
-assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
+assert st.var_vals == {"x": 3, "y": 4, "z": 5}

def test_shrink_reshape(self):
x = Variable("x", 1, 100).bind(3)
st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5)))
st = st.reshape((3*4*3,))
-assert st.var_vals == {Variable("x", 1, 100): 3}
+assert st.var_vals == {"x": 3}

class TestShapeTrackerUnbind(unittest.TestCase):
def test_view_unbind(self):

@@ -804,7 +804,7 @@ class TestSymInfer(unittest.TestCase):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
-var_vals = {a: 2, b: 3, c: 4}
+var_vals = {a.expr: 2, b.expr: 3, c.expr: 4}
assert sym_infer(5, var_vals) == 5
assert sym_infer(4.2, var_vals) == 4.2
assert sym_infer(a, var_vals) == 2
@@ -817,7 +817,7 @@ class TestSymInfer(unittest.TestCase):
def test_sym_infer_cdiv_cmod(self):
a = Variable("a", -1000, 1)
b = Variable("b", -1000, 1)
-var_vals = {a: 1, b: -1000}
+var_vals = {a.expr: 1, b.expr: -1000}
assert sym_infer(a%b, var_vals) == 1
assert sym_infer(a//b, var_vals) == 0

@@ -1,7 +1,7 @@
from typing import cast
import functools, math, time, multiprocessing, traceback, signal, atexit
from dataclasses import replace
-from tinygrad.uop.ops import Variable, sym_infer, AxisType, pyrender
+from tinygrad.uop.ops import sym_infer, AxisType, pyrender
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE
@@ -34,7 +34,7 @@ def get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, input_size / prod(test_global_size)

-def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:float|None=None,
+def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
@@ -141,7 +141,7 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr

try:
rawbufs = _ensure_buffer_alloc(rawbufs)
-var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:

@@ -16,7 +16,7 @@ class GraphException(Exception): pass

def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

-def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache: list[ExecItem] = []
@@ -73,7 +73,7 @@ def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer])
return input_replace

class GraphRunner(Runner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
@@ -82,7 +82,7 @@ class GraphRunner(Runner):

def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

-self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
+self.vars = sorted(var_vals.keys())
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
@@ -91,7 +91,7 @@ class GraphRunner(Runner):
for j,ji in enumerate(jit_cache):
estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner):
-if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
+if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]

global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
if global_dim_idx is not None or local_dim_idx is not None:
@@ -105,12 +105,12 @@ class GraphRunner(Runner):

super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

-def updated_vars(self, var_vals: dict[Variable, int]):
+def updated_vars(self, var_vals: dict[str, int]):
vals = [var_vals[v] for v in self.vars]
for j, vidxs in self.var_vals_replace.items():
for i, v in vidxs: yield j, i, vals[v]

-def updated_launch_dims(self, var_vals: dict[Variable, int]):
+def updated_launch_dims(self, var_vals: dict[str, int]):
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
for j, (gl, lc) in self.launch_dims_replace.items():
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
@@ -192,7 +192,7 @@ class CapturedJit(Generic[ReturnType]):
self.__post_init__()

# jit exec
-def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
+def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
# assign inputs
for idx, offset, device, size, dtype in self.extra_view_inputs:
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
@@ -225,7 +225,8 @@ def _prepare_jit_inputs(args, kwargs):
for lb in lbs if lb.base.realized is not None])
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
-var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+var_vals = {k.expr:v for k,v in _var_vals.items()}
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
return input_buffers, var_vals, names, st_vars_dtype_device

@@ -3,7 +3,7 @@ import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
@@ -56,9 +56,9 @@ class Runner:
self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
@property
def dev(self): return Device[self.device]
-def exec(self, rawbufs:list[Buffer], var_vals:dict[Variable, int]|None=None) -> float|None:
+def exec(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None) -> float|None:
return self(rawbufs, {} if var_vals is None else var_vals)
-def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
+def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
raise NotImplementedError("override this")

def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
@@ -89,7 +89,7 @@ class CompiledRunner(Runner):

def __reduce__(self): return self.__class__, (self.p, self.lib)

-def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
+def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
global_size, local_size = self.p.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
local_size = optimize_local_size(self._prg, global_size, rawbufs)
@@ -102,11 +102,11 @@ class CompiledRunner(Runner):
if local_size:
lra['local_size'] = tuple(local_size)
assert len(local_size) == 3, "local size must have len 3"
-return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
+return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)

class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
+def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"

class BufferCopy(Runner):
@@ -124,7 +124,7 @@ class BufferCopy(Runner):
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:
dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
-def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
+def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
dest, src = rawbufs[0:2]
assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
st = time.perf_counter()
@@ -159,8 +159,8 @@ class ExecItem:
prg: Runner
bufs: list[Buffer|None]
metadata: tuple[Metadata, ...]|None = None
-fixedvars: dict[Variable, int] = field(default_factory=dict)
-def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
+fixedvars: dict[str, int] = field(default_factory=dict)
+def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", self.prg.display_name, {"metadata":self.metadata, "var_vals":var_vals}))
@@ -211,7 +211,7 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem,

capturing: list = [] # put classes with an add method in here

-def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=None, do_update_stats=True):
+def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
for si, ei in lower_schedule(schedule):
if len(capturing) and CAPTURING: capturing[0].add(ei)
if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
@@ -229,4 +229,3 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=
np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else:
ei.run(var_vals, do_update_stats=do_update_stats)

@@ -1,7 +1,7 @@
from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
-from tinygrad.uop.ops import UOp, Variable, Ops, buffers
+from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same

@@ -12,15 +12,15 @@ class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
-fixedvars: dict[Variable, int] = field(default_factory=dict)
+fixedvars: dict[str, int] = field(default_factory=dict)

# **** schedule linearizer

-def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
# construct the KERNEL children graph based on assigns
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: dict[UOp, int] = {}
-var_vals: dict[Variable, int] = {}
+var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
k = u.src[1]
@@ -40,8 +40,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
pass # a BUFFER is already realized, nothing to do here
elif s.op is Ops.BIND:
var, val = s.unbind()
-assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
-var_vals[var] = val
+assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
+var_vals[var.expr] = val
else:
raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")

@@ -72,7 +72,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
+schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))

@@ -101,7 +101,7 @@ class ProgramSpec:
def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None

-def launch_dims(self, var_vals:dict[Variable, int]):
+def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size

@@ -4,12 +4,11 @@ import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
-from tinygrad.uop.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException

class CUDAGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)

# Check all jit items are compatible.
@@ -28,7 +27,7 @@ class CUDAGraph(MultiGraphRunner):
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

-c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
+c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

@@ -48,7 +47,7 @@ class CUDAGraph(MultiGraphRunner):

self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
+def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
from tinygrad.engine.jit import MultiGraphRunner

class HCQGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

@@ -69,7 +69,7 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
-self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
+self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}

for j,ji in enumerate(jit_cache):
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
@@ -183,7 +183,7 @@ class HCQGraph(MultiGraphRunner):
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
+def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -195,12 +195,13 @@ class HCQGraph(MultiGraphRunner):

if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

-hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
-**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
-**{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
+hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
+**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
+**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}

# Update rawbuffers
-for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
+for (j,i),input_idx in self.input_replace.items():
+hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr

for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))

@@ -5,7 +5,6 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
-from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str

@@ -17,7 +16,7 @@ class MTLResourceUsage:
MTLResourceUsageWrite = 0b10

class MetalGraph(GraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

@@ -48,7 +47,8 @@ class MetalGraph(GraphRunner):
if b is not None and b not in input_rawbuffers:
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
-for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
+for i,v in enumerate(prg.p.vars):
+msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)

global_size, local_size = prg.p.launch_dims(var_vals)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
@@ -61,7 +61,7 @@ class MetalGraph(GraphRunner):
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = to_struct(0, len(jit_cache))

-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
+def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()

@@ -1,5 +1,4 @@
import time, itertools
-from tinygrad.uop.ops import Variable
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
@@ -18,7 +17,7 @@ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)

class RemoteGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
+def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
@@ -93,7 +92,7 @@ class RemoteGraph(MultiGraphRunner):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
-def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
+def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:

@@ -100,7 +100,7 @@ class GraphComputeItem:
datahash: str
bufs: tuple[int, ...]
vars: tuple[Variable, ...]
-fixedvars: dict[Variable, int]
+fixedvars: dict[str, int]
ins: tuple[int, ...]
outs: tuple[int, ...]
global_size: tuple[sint, ...]|None
@@ -111,7 +111,7 @@ class GraphAlloc(RemoteRequest):
graph_num: int
jit_cache: tuple[GraphComputeItem|Transfer, ...]
bufs: tuple[tuple[SessionKey, int], ...]
-var_vals: dict[Variable, int]
+var_vals: dict[str, int]

@dataclass(frozen=True)
class GraphFree(RemoteRequest):
@@ -121,7 +121,7 @@ class GraphFree(RemoteRequest):
class GraphExec(RemoteRequest):
graph_num: int
bufs: tuple[tuple[SessionKey, int], ...]
-var_vals: dict[Variable, int]
+var_vals: dict[str, int]
wait: bool

# for safe deserialization

@@ -6,7 +6,7 @@ except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
-from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
+from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc

class MMIOInterface:
@@ -192,7 +192,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))

-def _apply_var_vals(self, var_vals:dict[Variable, int]):
+def _apply_var_vals(self, var_vals:dict[str, int]):
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]

for off, sym_idx in self.q_sints:
@@ -205,7 +205,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):

self._prev_resolved_syms = cast(list[int|None], resolved_syms)

-def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
+def submit(self, dev:HCQDeviceType, var_vals:dict[str, int]|None=None):
"""
Submits the command queue to a specific device for execution.

@@ -97,7 +97,7 @@ class ShapeTracker:
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])

@property
-def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
+def var_vals(self) -> dict[str, int]: return merge_dicts([{(vu:=v.unbind())[0].expr:vu[1]} for v in self.vars()])

def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])

@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
@@ -241,7 +241,7 @@ class Tensor(MathTrait):
_apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
return self

-def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]:
"""
Creates the schedule needed to realize these Tensor(s), with Variables.

@@ -37,7 +37,7 @@ def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x)

def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
-def sym_infer(uop: UOp|int, var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop

# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
@@ -486,7 +486,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
-def expr(self):
+def expr(self) -> str:
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
@@ -584,9 +584,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used

-def sym_infer(self, var_vals:dict[UOp, int]):
+def sym_infer(self, var_vals:dict[str, int]):
fxn, varnames = self._sym_fxn
-return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
+return fxn(**{k:v for k,v in var_vals.items() if k in varnames})

def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
with Context(TRACK_MATCH_STATS=0):