mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
more
This commit is contained in:
4
test/external/external_benchmark_resnet.py
vendored
4
test/external/external_benchmark_resnet.py
vendored
@@ -71,8 +71,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
|
||||
|
||||
y = x.sequential(layer).contiguous().contiguous_backward()
|
||||
y.sum().backward()
|
||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
||||
else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
||||
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
||||
else: sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
||||
|
||||
for _ in range(JITCNT):
|
||||
run_schedule(list(sched))
|
||||
|
||||
@@ -49,8 +49,8 @@ class BenchmarkBertTrain(unittest.TestCase):
|
||||
|
||||
y = layer(*inputs).contiguous().contiguous_backward()
|
||||
y.sum().backward()
|
||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
||||
else: sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
||||
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
||||
else: sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
||||
|
||||
for _ in range(JITCNT):
|
||||
run_schedule(sched)
|
||||
|
||||
@@ -865,7 +865,7 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
for src in ast.src:
|
||||
if (ret:=self._find_op(src, op)) is not None: return ret
|
||||
def _schedule_render(self, a: Tensor):
|
||||
schedule, _ = a.schedule_with_vars()
|
||||
schedule, _buffer_map, _var_vals = a.schedule_with_vars()
|
||||
for s in schedule:
|
||||
if s.ast.op is Ops.SINK:
|
||||
renderer = Device[s.bufs[0].device].renderer
|
||||
|
||||
@@ -481,7 +481,7 @@ class TestUOpMethod(unittest.TestCase):
|
||||
a = UOp.variable("a", 1, 10)
|
||||
uop_var = Tensor(a.bind(1))
|
||||
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
|
||||
_, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
_, _, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
self.assertEqual(len(var_vals), 1)
|
||||
self.assertEqual(list(var_vals)[0], a.expr)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class TestRingAllReduce(unittest.TestCase):
|
||||
N = 4
|
||||
ds = tuple(f"CPU:{i}" for i in range(N))
|
||||
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
|
||||
schedules = t.sum(0).schedule_with_vars()[0]
|
||||
schedules, _, _ = t.sum(0).schedule_with_vars()
|
||||
copies = [si for si in schedules if si.ast.op is Ops.COPY]
|
||||
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
|
||||
# N*(N-1) scatter reduce, and N*(N-1) allgather
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestScheduleCache(unittest.TestCase):
|
||||
x = Tensor.ones(10).contiguous().realize()
|
||||
|
||||
t = x + Tensor(v.bind(42))
|
||||
_, var_vals = t.schedule_with_vars()
|
||||
_, _, var_vals = t.schedule_with_vars()
|
||||
self.assertEqual(var_vals, {'pos': 42})
|
||||
|
||||
def test_simple(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from tinygrad.helpers import DEBUG, GlobalCounters, all_same, dedup, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
|
||||
from typing import Any, cast
|
||||
from tinygrad.helpers import DEBUG, GlobalCounters, all_same, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
|
||||
from tinygrad.uop.ops import UOp, Ops, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
|
||||
@@ -37,15 +37,14 @@ class ExecutionUnit:
|
||||
|
||||
self._bound_items = []
|
||||
for item in self.items:
|
||||
# Get buffers - either from buffer_map (for UOps) or directly (for already-bound Buffers)
|
||||
# Get buffers - prefer buf_uops with buffer_map, fall back to bufs for backwards compatibility
|
||||
bufs: list[Buffer] = []
|
||||
for b in item.bufs:
|
||||
if b is None:
|
||||
continue
|
||||
if isinstance(b, UOp):
|
||||
bufs.append(self.buffer_map[b])
|
||||
else:
|
||||
bufs.append(b)
|
||||
if item.buf_uops:
|
||||
for uop in item.buf_uops:
|
||||
bufs.append(cast(Buffer, self.buffer_map.get(uop) or uop.buffer))
|
||||
else:
|
||||
for buf in item.bufs:
|
||||
if buf is not None: bufs.append(buf)
|
||||
|
||||
# Create runner from lib or use existing prg
|
||||
if item.prg is not None:
|
||||
|
||||
@@ -63,8 +63,18 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
|
||||
|
||||
return assigned
|
||||
|
||||
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
def memory_planner(schedule:list[ExecItem]) -> tuple[list[ExecItem], dict[UOp, Buffer]]:
|
||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
|
||||
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
|
||||
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
|
||||
new_schedule = [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars, buf_uops=si.buf_uops)
|
||||
for si in schedule]
|
||||
# Build buffer_map from buf_uops -> assigned buffers
|
||||
buffer_map: dict[UOp, Buffer] = {}
|
||||
for si in new_schedule:
|
||||
for i, uop in enumerate(si.buf_uops):
|
||||
if uop not in buffer_map and i < len(si.bufs) and si.bufs[i] is not None:
|
||||
buffer_map[uop] = si.bufs[i] # type: ignore
|
||||
return new_schedule, buffer_map
|
||||
|
||||
@@ -183,11 +183,12 @@ si_lowerer = PatternMatcher([
|
||||
@dataclass
|
||||
class ExecItem:
|
||||
ast: UOp
|
||||
bufs: list[Buffer|None] = field(default_factory=list)
|
||||
bufs: list[Buffer|None] = field(default_factory=list) # TODO: deprecate, use buf_uops + buffer_map
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
fixedvars: dict[str, int] = field(default_factory=dict)
|
||||
lib: bytes|None = None # compiled binary, None for COPY/VIEW/ENCDEC
|
||||
prg: Runner|None = None
|
||||
buf_uops: tuple[UOp, ...] = () # buffer UOps, binding happens in ExecutionUnit
|
||||
|
||||
def lower(self):
|
||||
"""Populate self.prg and self.lib by lowering the AST."""
|
||||
@@ -242,8 +243,9 @@ class ExecItem:
|
||||
|
||||
capturing: list = [] # put classes with an add method in here
|
||||
|
||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
def run_schedule(schedule:list[ExecItem], buffer_map:dict[UOp, Buffer]|None=None, var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
from tinygrad.engine.execution import ExecutionUnit
|
||||
if buffer_map is None: buffer_map = {}
|
||||
|
||||
# Lower all items first
|
||||
lowered: list[ExecItem] = []
|
||||
@@ -263,7 +265,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
|
||||
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
|
||||
|
||||
# run on GPU
|
||||
ExecutionUnit([ei]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
with Context(BEAM=0):
|
||||
@@ -272,8 +274,8 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
|
||||
assert nb[0] is not None
|
||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
ExecutionUnit([ei]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
else:
|
||||
# Use ExecutionUnit for batched execution
|
||||
if lowered:
|
||||
ExecutionUnit(lowered).update(var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
ExecutionUnit(lowered).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
|
||||
@@ -125,7 +125,7 @@ pm_post_sched_cache = PatternMatcher([
|
||||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
@@ -185,11 +185,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}), buf_uops=buf_uops))
|
||||
else:
|
||||
# ONE -> ONE
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars, buf_uops=buf_uops))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule, buffer_map = memory_planner(schedule)
|
||||
|
||||
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
||||
var_vals: dict[str, int] = {}
|
||||
@@ -204,4 +204,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
return tensor_map, schedule, var_vals
|
||||
return tensor_map, schedule, buffer_map, var_vals
|
||||
|
||||
@@ -237,7 +237,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||
"""
|
||||
Creates the schedule needed to realize these Tensor(s), with Variables.
|
||||
|
||||
@@ -246,13 +246,13 @@ class Tensor(OpMixin):
|
||||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||
|
||||
# this is where the schedule cache should go
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
becomes_map, schedule, buffer_map, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
||||
return schedule, var_vals
|
||||
return schedule, buffer_map, var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
||||
"""Creates the schedule needed to realize these Tensor(s)."""
|
||||
schedule, var_vals = self.schedule_with_vars(*lst)
|
||||
schedule, _buffer_map, var_vals = self.schedule_with_vars(*lst)
|
||||
assert len(var_vals) == 0
|
||||
return schedule
|
||||
|
||||
@@ -260,7 +260,8 @@ class Tensor(OpMixin):
|
||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
schedule, buffer_map, var_vals = Tensor.schedule_with_vars(*to_realize)
|
||||
run_schedule(schedule, buffer_map, var_vals, do_update_stats=do_update_stats)
|
||||
return self
|
||||
|
||||
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user