George Hotz
2025-12-21 15:48:59 -04:00
parent 5515de6553
commit 76f2f14233
11 changed files with 47 additions and 35 deletions

View File

@@ -71,8 +71,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
y = x.sequential(layer).contiguous().contiguous_backward()
y.sum().backward()
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
else: sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
for _ in range(JITCNT):
run_schedule(list(sched))
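
For callers, the only change is wider unpacking: schedule_with_vars now yields (schedule, buffer_map, var_vals) instead of (schedule, var_vals), and the extra value can be discarded where it is not needed. A minimal sketch of the updated pattern with a toy layer standing in for the ResNet block (the run_schedule import path is assumed):

from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import run_schedule  # import path assumed

layer = Linear(16, 16)                       # toy stand-in for the benchmarked layer
optim = SGD(get_parameters(layer))
x = Tensor.randn(4, 16, requires_grad=True)

y = layer(x).contiguous().contiguous_backward()
y.sum().backward()
# new: three values come back; buffer_map is unused at this call site
sched, _buffer_map, _var_vals = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
run_schedule(list(sched))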

View File

@@ -49,8 +49,8 @@ class BenchmarkBertTrain(unittest.TestCase):
y = layer(*inputs).contiguous().contiguous_backward()
y.sum().backward()
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
else: sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
else: sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
for _ in range(JITCNT):
run_schedule(sched)

View File

@@ -865,7 +865,7 @@ class TestIdxUpcast(unittest.TestCase):
for src in ast.src:
if (ret:=self._find_op(src, op)) is not None: return ret
def _schedule_render(self, a: Tensor):
- schedule, _ = a.schedule_with_vars()
+ schedule, _buffer_map, _var_vals = a.schedule_with_vars()
for s in schedule:
if s.ast.op is Ops.SINK:
renderer = Device[s.bufs[0].device].renderer
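
Where only the schedule itself is needed, the two extra return values can simply be ignored, as the updated test does. The same pattern works outside the test for checking which backend each scheduled kernel will be rendered with (a sketch; treating bufs[0] as the output buffer follows the test's own assumption):

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.uop.ops import Ops

a = Tensor.arange(1024).reshape(32, 32).sum(axis=1)
schedule, _buffer_map, _var_vals = a.schedule_with_vars()
for s in schedule:
  if s.ast.op is Ops.SINK:                      # actual compute kernels, not copies/views
    renderer = Device[s.bufs[0].device].renderer
    print(type(renderer).__name__)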

View File

@@ -481,7 +481,7 @@ class TestUOpMethod(unittest.TestCase):
a = UOp.variable("a", 1, 10)
uop_var = Tensor(a.bind(1))
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
- _, var_vals = (uop_var+st_var).schedule_with_vars()
+ _, _, var_vals = (uop_var+st_var).schedule_with_vars()
self.assertEqual(len(var_vals), 1)
self.assertEqual(list(var_vals)[0], a.expr)
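
Symbolic shapes are unaffected beyond the extra slot: var_vals simply moves from the second to the third position in the returned tuple. A small sketch mirroring the test above:

from tinygrad import Tensor
from tinygrad.uop.ops import UOp

a = UOp.variable("a", 1, 10)                    # symbolic dim in [1, 10]
uop_var = Tensor(a.bind(1))
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]   # symbolically sized slice
_, _, var_vals = (uop_var + st_var).schedule_with_vars()
assert len(var_vals) == 1 and list(var_vals)[0] == a.expr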

View File

@@ -9,7 +9,7 @@ class TestRingAllReduce(unittest.TestCase):
N = 4
ds = tuple(f"CPU:{i}" for i in range(N))
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
- schedules = t.sum(0).schedule_with_vars()[0]
+ schedules, _, _ = t.sum(0).schedule_with_vars()
copies = [si for si in schedules if si.ast.op is Ops.COPY]
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
# N*(N-1) scatter reduce, and N*(N-1) allgather
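
Unpacking the triple also covers multi-device schedules, where the COPY items are usually what one wants to inspect. A sketch along the lines of the test (with ring allreduce active, the test's comment says to expect N*(N-1) scatter-reduce plus N*(N-1) allgather copies):

from tinygrad import Tensor
from tinygrad.uop.ops import Ops

N = 4
ds = tuple(f"CPU:{i}" for i in range(N))
t = Tensor.empty(N, N * 100).shard(ds, axis=0).realize()
schedule, _, _ = t.sum(0).schedule_with_vars()
copies = [si for si in schedule if si.ast.op is Ops.COPY]
print([(c.bufs[0].device, c.bufs[1].device) for c in copies])   # device-to-device hops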

View File

@@ -23,7 +23,7 @@ class TestScheduleCache(unittest.TestCase):
x = Tensor.ones(10).contiguous().realize()
t = x + Tensor(v.bind(42))
- _, var_vals = t.schedule_with_vars()
+ _, _, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42})
def test_simple(self):

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
- from typing import Any
- from tinygrad.helpers import DEBUG, GlobalCounters, all_same, dedup, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
+ from typing import Any, cast
+ from tinygrad.helpers import DEBUG, GlobalCounters, all_same, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
from tinygrad.uop.ops import UOp, Ops, sym_infer
from tinygrad.device import Device, Buffer
@@ -37,15 +37,14 @@ class ExecutionUnit:
self._bound_items = []
for item in self.items:
- # Get buffers - either from buffer_map (for UOps) or directly (for already-bound Buffers)
+ # Get buffers - prefer buf_uops with buffer_map, fall back to bufs for backwards compatibility
bufs: list[Buffer] = []
- for b in item.bufs:
- if b is None:
- continue
- if isinstance(b, UOp):
- bufs.append(self.buffer_map[b])
- else:
- bufs.append(b)
+ if item.buf_uops:
+ for uop in item.buf_uops:
+ bufs.append(cast(Buffer, self.buffer_map.get(uop) or uop.buffer))
+ else:
+ for buf in item.bufs:
+ if buf is not None: bufs.append(buf)
# Create runner from lib or use existing prg
if item.prg is not None:
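
In prose: each ExecItem that carries buf_uops is bound by looking every UOp up in buffer_map first and falling back to the UOp's own .buffer; items without buf_uops keep using their pre-bound bufs list. A standalone sketch of that resolution order (the helper name is hypothetical, not part of the codebase):

from typing import cast
from tinygrad.uop.ops import UOp
from tinygrad.device import Buffer

def resolve_item_buffers(item, buffer_map: dict[UOp, Buffer]) -> list[Buffer]:
  """Hypothetical helper mirroring ExecutionUnit's binding order for one ExecItem."""
  bufs: list[Buffer] = []
  if item.buf_uops:
    # preferred path: planner-assigned buffer if present, else the UOp's own buffer
    for uop in item.buf_uops:
      bufs.append(cast(Buffer, buffer_map.get(uop) or uop.buffer))
  else:
    # legacy path: the item was constructed with already-bound Buffers
    bufs.extend(b for b in item.bufs if b is not None)
  return bufs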

View File

@@ -63,8 +63,18 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
return assigned
- def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
+ from tinygrad.uop.ops import UOp
+ def memory_planner(schedule:list[ExecItem]) -> tuple[list[ExecItem], dict[UOp, Buffer]]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
- return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
+ new_schedule = [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars, buf_uops=si.buf_uops)
+ for si in schedule]
+ # Build buffer_map from buf_uops -> assigned buffers
+ buffer_map: dict[UOp, Buffer] = {}
+ for si in new_schedule:
+ for i, uop in enumerate(si.buf_uops):
+ if uop not in buffer_map and i < len(si.bufs) and si.bufs[i] is not None:
+ buffer_map[uop] = si.bufs[i] # type: ignore
+ return new_schedule, buffer_map
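
memory_planner now returns the planned schedule together with a buffer_map that records, per buffer UOp, which (possibly reused) Buffer the planner assigned to it. From user code the map surfaces as the second element of schedule_with_vars; a hedged sketch of looking at it:

from tinygrad import Tensor

t = Tensor.rand(64, 64) @ Tensor.rand(64, 64)
schedule, buffer_map, _var_vals = t.schedule_with_vars()
planned = sum(uop in buffer_map for item in schedule for uop in item.buf_uops)
print(f"{planned} buffer uops resolved through the planner's buffer_map")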

View File

@@ -183,11 +183,12 @@ si_lowerer = PatternMatcher([
@dataclass
class ExecItem:
ast: UOp
- bufs: list[Buffer|None] = field(default_factory=list)
+ bufs: list[Buffer|None] = field(default_factory=list) # TODO: deprecate, use buf_uops + buffer_map
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[str, int] = field(default_factory=dict)
lib: bytes|None = None # compiled binary, None for COPY/VIEW/ENCDEC
prg: Runner|None = None
+ buf_uops: tuple[UOp, ...] = () # buffer UOps, binding happens in ExecutionUnit
def lower(self):
"""Populate self.prg and self.lib by lowering the AST."""
@@ -242,8 +243,9 @@ class ExecItem:
capturing: list = [] # put classes with an add method in here
- def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
+ def run_schedule(schedule:list[ExecItem], buffer_map:dict[UOp, Buffer]|None=None, var_vals:dict[str, int]|None=None, do_update_stats=True):
from tinygrad.engine.execution import ExecutionUnit
+ if buffer_map is None: buffer_map = {}
# Lower all items first
lowered: list[ExecItem] = []
@@ -263,7 +265,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
# run on GPU
- ExecutionUnit([ei]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
+ ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
with Context(BEAM=0):
@@ -272,8 +274,8 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
assert nb[0] is not None
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else:
- ExecutionUnit([ei]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
+ ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
else:
# Use ExecutionUnit for batched execution
if lowered:
- ExecutionUnit(lowered).update(var_vals=var_vals)(do_update_stats=do_update_stats)
+ ExecutionUnit(lowered).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
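
run_schedule gains buffer_map as its second parameter and forwards it as buffers= to every ExecutionUnit it builds; omitting it keeps the old behaviour, since binding then falls back to each buffer UOp's own Buffer. A sketch of the new call shape (the run_schedule import path is assumed):

from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule  # import path assumed

out = (Tensor.ones(8, 8).contiguous() + 1).sum()
schedule, buffer_map, var_vals = out.schedule_with_vars()
run_schedule(schedule, buffer_map, var_vals)   # new three-argument form
# run_schedule(schedule) is still valid: buffer_map defaults to {} inside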

View File

@@ -125,7 +125,7 @@ pm_post_sched_cache = PatternMatcher([
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
- def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
+ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
@@ -185,11 +185,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
- schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
+ schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}), buf_uops=buf_uops))
else:
# ONE -> ONE
- schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
- with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+ schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars, buf_uops=buf_uops))
+ with cpu_profile(TracingKey("memory planner")): schedule, buffer_map = memory_planner(schedule)
# extract var_vals from BINDs that were stripped (only if there are kernels)
var_vals: dict[str, int] = {}
@@ -204,4 +204,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
- return tensor_map, schedule, var_vals
+ return tensor_map, schedule, buffer_map, var_vals

View File

@@ -237,7 +237,7 @@ class Tensor(OpMixin):
"""
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
- def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
+ def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
"""
Creates the schedule needed to realize these Tensor(s), with Variables.
@@ -246,13 +246,13 @@ class Tensor(OpMixin):
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# this is where the schedule cache should go
- becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
+ becomes_map, schedule, buffer_map, var_vals = complete_create_schedule_with_vars(big_sink)
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
- return schedule, var_vals
+ return schedule, buffer_map, var_vals
def schedule(self, *lst:Tensor) -> list[ExecItem]:
"""Creates the schedule needed to realize these Tensor(s)."""
- schedule, var_vals = self.schedule_with_vars(*lst)
+ schedule, _buffer_map, var_vals = self.schedule_with_vars(*lst)
assert len(var_vals) == 0
return schedule
@@ -260,7 +260,8 @@ class Tensor(OpMixin):
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
- run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
+ schedule, buffer_map, var_vals = Tensor.schedule_with_vars(*to_realize)
+ run_schedule(schedule, buffer_map, var_vals, do_update_stats=do_update_stats)
return self
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
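
Net effect at the Tensor level: schedule_with_vars returns the triple, schedule still returns only the ExecItem list (dropping buffer_map internally), and realize now unpacks the triple explicitly before handing it to run_schedule. A short user-facing sketch:

from tinygrad import Tensor

y = (Tensor.rand(4, 4) + 1).relu()
schedule, buffer_map, var_vals = y.schedule_with_vars()   # new: three values
assert len(var_vals) == 0                                 # nothing symbolic here
# y.schedule() is unchanged for callers: it drops buffer_map and returns the list;
# y.realize() now does the unpack above itself and calls
# run_schedule(schedule, buffer_map, var_vals) under the hood.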