Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 06:58:11 -05:00
factor scheduling into complete_create_schedule_with_vars (#13464)
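
This factors the full scheduling pipeline (spec verification, multi-device mapping, the rangeify/kernelize rewrite, toposort scheduling, memory planning, and AFTER removal) out of Tensor.schedule_with_vars and into a single complete_create_schedule_with_vars(big_sink) in tinygrad/engine/schedule.py, which returns the tensor rewrite map, the ScheduleItem list, and the variable values.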
tinygrad/engine/schedule.py:
@@ -1,9 +1,11 @@
+import time
 from typing import cast
 from dataclasses import dataclass, field, replace
 from collections import deque, defaultdict
 from tinygrad.uop.ops import UOp, Ops, buffers
+from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Device, Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, all_same
+from tinygrad.helpers import Metadata, all_same, DEBUG, cpu_profile, TracingKey, SPEC, flatten
 
 # **** ScheduleItem return type
 
@@ -114,3 +116,35 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
     sched_ptr += 1
   return real_schedule, var_vals
+
+from tinygrad.engine.memory import memory_planner
+from tinygrad.schedule.rangeify import get_rangeify_map
+from tinygrad.schedule.multi import get_multi_map
+
+def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
+  # big_sink srcs are all the Tensors
+  st = time.perf_counter()
+
+  # verify Tensors match the spec
+  if SPEC: type_verify(big_sink, tensor_spec)
+
+  # tensor map is what we return
+  tensor_map: dict[UOp, UOp] = {}
+
+  if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
+    tensor_map |= get_multi_map(big_sink)
+    big_sink = big_sink.substitute(tensor_map, name="Apply Multi Map")
+    big_sink = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink.src]))
+
+  tensor_map |= get_rangeify_map(big_sink)
+  big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
+
+  # create the schedule
+  with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(big_sink)
+  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+
+  # remove all AFTERs, after scheduling, the tensors are just buffers
+  tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
+
+  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+  return tensor_map, schedule, var_vals
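
For orientation, a minimal sketch of how the new entry point can be driven. This example is illustrative, not part of the commit; the tensor setup is arbitrary, and only the names visible in the diff (Tensor.uop, UOp.sink, complete_create_schedule_with_vars) are assumed.

from tinygrad import Tensor
from tinygrad.uop.ops import UOp
from tinygrad.engine.schedule import complete_create_schedule_with_vars

# sink over the UOps of every tensor to realize (big_sink srcs are the Tensors)
t = Tensor.arange(16).reshape(4, 4).contiguous()
big_sink = UOp.sink(t.uop)

# tensor_map rewrites each tensor UOp to its post-scheduling form (a buffer),
# schedule is the ordered ScheduleItem list, var_vals binds symbolic variables
tensor_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)

Callers that hold Tensors, like Tensor.schedule_with_vars below, then rewrite their uops through the returned map so later realizes see buffers instead of kernels.
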
tinygrad/tensor.py:
@@ -6,19 +6,15 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
-from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC, TracingKey, cpu_profile
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin
 from tinygrad.mixin.movement import _align_left
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop
-from tinygrad.uop.spec import type_verify, tensor_spec
+from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
-from tinygrad.engine.memory import memory_planner
-from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.schedule.multi import get_multi_map
 
 # TODO: this should be the only usage of Device
 def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
@@ -234,30 +230,9 @@ class Tensor(OpMixin):
 
     NOTE: A Tensor can only be scheduled once.
     """
-    st = time.perf_counter()
     big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
-
-    # verify Tensors match the spec
-    if SPEC: type_verify(big_sink, tensor_spec)
-
-    if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
-      _apply_map_to_tensors(get_multi_map(big_sink), name="Apply Multi Map")
-      big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
-
-    becomes_map = get_rangeify_map(big_sink)
-    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
-
-    # get new sink
-    sink = UOp.sink(*[x.uop for x in (self,)+lst])
-
-    # remove all AFTERs, after scheduling, the tensors are just buffers
-    remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.AFTER}
-    _apply_map_to_tensors(remove_assign_map, name="Remove After")
-
-    # create the schedule
-    with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(sink)
-    with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
-    if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return schedule, var_vals
 
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
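
With the pipeline factored out, Tensor.schedule_with_vars reduces to three steps: build the sink, delegate, and apply the returned map back onto the tensors. A hedged usage sketch from the caller's side (the tensors and ops are illustrative, not from this commit):

from tinygrad import Tensor

a = Tensor.ones(2, 2).contiguous()
b = (a + 1).sum()
# one call schedules both tensors; var_vals holds any symbolic bindings
schedule, var_vals = a.schedule_with_vars(b)
print(f"{len(schedule)} schedule items")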