diff --git a/tinygrad/device.py b/tinygrad/device.py index d7f2c89a22..2f76efeee8 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from collections import defaultdict from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array -from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE +from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE from tinygrad.dtype import DType, ImageDType from tinygrad.renderer import Renderer @@ -90,7 +90,7 @@ class Buffer: if self._base is not None: return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf')) if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount) - if self.is_allocated(): + if self.is_allocated() and not SAVE_SCHEDULE: buf = bytearray(self.nbytes) self.copyout(memoryview(buf)) return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6e41d4a8ee..0a0478fdc2 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -337,6 +337,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe with open(fp, "wb") as f: pickle.dump(SCHEDULES, f) if len(SCHEDULES) == 0: atexit.register(_save) SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)]) + if SAVE_SCHEDULE.value == len(SCHEDULES): exit(0) # confirm everything was scheduled correctly if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule): raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")