mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
more save_schedule tooling (#5547)
This commit is contained in:
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
||||
from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
|
||||
from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -90,7 +90,7 @@ class Buffer:
|
||||
if self._base is not None:
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
|
||||
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
|
||||
if self.is_allocated():
|
||||
if self.is_allocated() and not SAVE_SCHEDULE:
|
||||
buf = bytearray(self.nbytes)
|
||||
self.copyout(memoryview(buf))
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
|
||||
|
||||
@@ -337,6 +337,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
|
||||
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
|
||||
if len(SCHEDULES) == 0: atexit.register(_save)
|
||||
SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
|
||||
if SAVE_SCHEDULE.value == len(SCHEDULES): exit(0)
|
||||
# confirm everything was scheduled correctly
|
||||
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
|
||||
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
|
||||
|
||||
Reference in New Issue
Block a user