diff --git a/extra/introspection.py b/extra/introspection.py
index f379191001..f3e301aacd 100644
--- a/extra/introspection.py
+++ b/extra/introspection.py
@@ -2,7 +2,7 @@
 import gc
 from tinygrad.helpers import prod
 from tinygrad.lazy import LazyBuffer
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from tinygrad import Tensor, GlobalCounters
 
 def print_objects():
diff --git a/openpilot/compile2.py b/openpilot/compile2.py
index 835b6f5e8a..8a7a730ac7 100644
--- a/openpilot/compile2.py
+++ b/openpilot/compile2.py
@@ -14,9 +14,8 @@ import onnx
 from typing import Tuple, List, Optional, Dict, cast
 from extra.onnx import get_run_onnx
 from tinygrad import Tensor, Device, GlobalCounters, dtypes
-from tinygrad.buffer import Buffer
 from tinygrad.dtype import ImageDType
-from tinygrad.device import CompiledRunner
+from tinygrad.device import CompiledRunner, Buffer
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem
 from tinygrad.engine.memory import memory_planner
diff --git a/test/external/external_test_hcq.py b/test/external/external_test_hcq.py
index 2ac39d5726..9677bcd9c2 100644
--- a/test/external/external_test_hcq.py
+++ b/test/external/external_test_hcq.py
@@ -1,7 +1,7 @@
 import unittest, ctypes, struct, time, array
 from tinygrad import Device, Tensor, dtypes
 from tinygrad.helpers import to_mv
-from tinygrad.buffer import Buffer, BufferOptions
+from tinygrad.device import Buffer, BufferOptions
 from tinygrad.engine.schedule import create_schedule
 
 def _time_queue(q, d):
diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py
index eb4eaaef66..eb946655b1 100644
--- a/test/external/fuzz_schedule.py
+++ b/test/external/fuzz_schedule.py
@@ -1,7 +1,7 @@
 import itertools
 import numpy as np
 from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
 from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
 from tinygrad.lazy import LazyBuffer
diff --git a/test/test_subbuffer.py b/test/test_subbuffer.py
index 85a55384f6..cd0f42ef39 100644
--- a/test/test_subbuffer.py
+++ b/test/test_subbuffer.py
@@ -1,7 +1,7 @@
 import unittest
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from tinygrad.lazy import view_supported_devices
 
 @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
diff --git a/tinygrad/buffer.py b/tinygrad/buffer.py
deleted file mode 100644
index fc9d739ed4..0000000000
--- a/tinygrad/buffer.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from __future__ import annotations
-from typing import Any, Optional, Dict, Tuple
-import ctypes
-from collections import defaultdict
-from dataclasses import dataclass
-from tinygrad.helpers import GlobalCounters, flat_mv, from_mv, getenv
-from tinygrad.dtype import DType, ImageDType
-
-@dataclass(frozen=True, eq=True)
-class BufferOptions:
-  image: Optional[ImageDType] = None
-  uncached: bool = False
-  cpu_access: bool = False
-  host: bool = False
-  nolru: bool = False
-
-class Buffer:
-  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
-               initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
-    assert isinstance(dtype, DType)
-    if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
-    self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
-    if base is None:
-      assert offset == 0, "base buffers can't have offset"
-      self._base = None
-      self._lb_refcount = lb_refcount
-      if opaque is not None: self.allocate(opaque)
-      if initial_value is not None:
-        self.allocate()
-        self.copyin(memoryview(initial_value))
-    else:
-      assert base._base is None, "base can't have a base"
-      assert device == base.device, "base must have the same device"
-      self._base = base
-    if preallocate: self.allocate()
-  @property
-  def base(self) -> Buffer: return self._base if self._base is not None else self
-  @property
-  def lb_refcount(self): return self.base._lb_refcount
-  def ref(self, cnt): self.base._lb_refcount += cnt
-  def is_allocated(self) -> bool: return hasattr(self, '_buf')
-  def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
-  def allocate(self, opaque=None) -> Buffer:
-    assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
-    from tinygrad.device import Device
-    self.allocator = Device[self.device].allocator
-    if self._base is not None:
-      self._base.ensure_allocated()
-      assert hasattr(self.allocator, "offset"), "offset function required for view"
-      self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
-    else:
-      self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
-      if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
-    return self
-  def __reduce__(self):
-    buf = None
-    if self._base is not None:
-      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
-    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
-    if self.is_allocated():
-      buf = bytearray(self.nbytes)
-      self.copyout(memoryview(buf))
-    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
-  @property
-  def nbytes(self): return self.size*self.dtype.itemsize
-  def __del__(self):
-    if not hasattr(self, '_buf'): return
-    if self._base is None:
-      if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
-      self.allocator.free(self._buf, self.nbytes, self.options)
-  def __repr__(self):
-    return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
-           (">" if self.options is None else f" {self.options=}>")
-  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
-    # zero copy with as_buffer (disabled by default due to use after free)
-    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
-    assert not force_zero_copy, "force zero copy was passed, but copy is required"
-    return self.copyout(memoryview(bytearray(self.nbytes)))
-  def copyin(self, mv:memoryview):
-    mv = flat_mv(mv)
-    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
-    assert self.is_allocated(), "can't copyin to unallocated buffer"
-    self.allocator.copyin(self._buf, mv)
-    return self
-  def copyout(self, mv:memoryview) -> memoryview:
-    mv = flat_mv(mv)
-    assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
-    assert self.is_allocated(), "can't copyout unallocated buffer"
-    self.allocator.copyout(mv, self._buf)
-    return mv
-  def view(self, size:int, dtype:DType, offset:int) -> Buffer:
-    assert offset < self.nbytes, "offset must be less than nbytes"
-    if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
-    return Buffer(self.device, size, dtype, base=self, offset=offset)
-
-# TODO: size, dest, src are the same type. can we enforce this?
-class Allocator:
-  def alloc(self, size:int, options:Optional[BufferOptions]=None):
-    assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
-    return self._alloc(size, options if options is not None else BufferOptions())
-  def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
-  def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
-    self._free(opaque, options if options is not None else BufferOptions())
-  def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
-  def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
-  def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
-
-class LRUAllocator(Allocator): # pylint: disable=abstract-method
-  def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
-  def alloc(self, size:int, options:Optional[BufferOptions]=None):
-    if len(c := self.cache[(size, options)]): return c.pop()
-    try: return super().alloc(size, options)
-    except (RuntimeError, MemoryError):
-      self.free_cache()
-      return super().alloc(size, options)
-  def free_cache(self):
-    for (sz,options),opaques in self.cache.items():
-      for opaque in opaques: super().free(opaque, sz, options)
-      opaques.clear()
-  def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
-    if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
-    else: super().free(opaque, size, options)
-
-class _MallocAllocator(LRUAllocator):
-  def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
-  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
-  def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
-  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
-  def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
-
-MallocAllocator = _MallocAllocator()
diff --git a/tinygrad/device.py b/tinygrad/device.py
index cc355199ed..d59765e61e 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 import multiprocessing
 from dataclasses import dataclass, replace
-from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable
-import importlib, inspect, functools, pathlib, os
-from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
+from collections import defaultdict
+from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable, Any
+import importlib, inspect, functools, pathlib, os, ctypes
+from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.ops import LazyOp, get_lazyop_info
-from tinygrad.buffer import Buffer, Allocator
 from tinygrad.codegen.uops import UOpGraph
 
 if TYPE_CHECKING:
@@ -40,6 +41,132 @@ class _Device:
     raise RuntimeError("no usable devices")
 Device = _Device()
 
+# **************** Buffer + Allocators ****************
+
+@dataclass(frozen=True, eq=True)
+class BufferOptions:
+  image: Optional[ImageDType] = None
+  uncached: bool = False
+  cpu_access: bool = False
+  host: bool = False
+  nolru: bool = False
+
+class Buffer:
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
+               initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
+    assert isinstance(dtype, DType)
+    if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
+    self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
+    if base is None:
+      assert offset == 0, "base buffers can't have offset"
+      self._base = None
+      self._lb_refcount = lb_refcount
+      if opaque is not None: self.allocate(opaque)
+      if initial_value is not None:
+        self.allocate()
+        self.copyin(memoryview(initial_value))
+    else:
+      assert base._base is None, "base can't have a base"
+      assert device == base.device, "base must have the same device"
+      self._base = base
+    if preallocate: self.allocate()
+  @property
+  def base(self) -> Buffer: return self._base if self._base is not None else self
+  @property
+  def lb_refcount(self): return self.base._lb_refcount
+  def ref(self, cnt): self.base._lb_refcount += cnt
+  def is_allocated(self) -> bool: return hasattr(self, '_buf')
+  def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
+  def allocate(self, opaque=None) -> Buffer:
+    assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
+    self.allocator = Device[self.device].allocator
+    if self._base is not None:
+      self._base.ensure_allocated()
+      assert hasattr(self.allocator, "offset"), "offset function required for view"
+      self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
+    else:
+      self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
+      if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
+    return self
+  def __reduce__(self):
+    buf = None
+    if self._base is not None:
+      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
+    if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
+    if self.is_allocated():
+      buf = bytearray(self.nbytes)
+      self.copyout(memoryview(buf))
+    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
+  @property
+  def nbytes(self): return self.size*self.dtype.itemsize
+  def __del__(self):
+    if not hasattr(self, '_buf'): return
+    if self._base is None:
+      if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+      self.allocator.free(self._buf, self.nbytes, self.options)
+  def __repr__(self):
+    return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
+           (">" if self.options is None else f" {self.options=}>")
+  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
+    # zero copy with as_buffer (disabled by default due to use after free)
+    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
"force zero copy was passed, but copy is required" + return self.copyout(memoryview(bytearray(self.nbytes))) + def copyin(self, mv:memoryview): + mv = flat_mv(mv) + assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}" + assert self.is_allocated(), "can't copyin to unallocated buffer" + self.allocator.copyin(self._buf, mv) + return self + def copyout(self, mv:memoryview) -> memoryview: + mv = flat_mv(mv) + assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}" + assert self.is_allocated(), "can't copyout unallocated buffer" + self.allocator.copyout(mv, self._buf) + return mv + def view(self, size:int, dtype:DType, offset:int) -> Buffer: + assert offset < self.nbytes, "offset must be less than nbytes" + if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset) + return Buffer(self.device, size, dtype, base=self, offset=offset) + +# TODO: size, dest, src are the same type. can we enforce this? +class Allocator: + def alloc(self, size:int, options:Optional[BufferOptions]=None): + assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}" + return self._alloc(size, options if options is not None else BufferOptions()) + def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc") + def free(self, opaque, size:int, options:Optional[BufferOptions]=None): + self._free(opaque, options if options is not None else BufferOptions()) + def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free + def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin") + def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout") + +class LRUAllocator(Allocator): # pylint: disable=abstract-method + def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list) + def alloc(self, size:int, options:Optional[BufferOptions]=None): + if len(c := self.cache[(size, options)]): return c.pop() + try: return super().alloc(size, options) + except (RuntimeError, MemoryError): + self.free_cache() + return super().alloc(size, options) + def free_cache(self): + for (sz,options),opaques in self.cache.items(): + for opaque in opaques: super().free(opaque, sz, options) + opaques.clear() + def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None): + if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque) + else: super().free(opaque, size, options) + +class _MallocAllocator(LRUAllocator): + def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)() + def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src)) + def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src)) + def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest)) + def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size]) + +MallocAllocator = _MallocAllocator() + # **************** base Runner + helpers **************** class Runner: diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 9ebcf0961e..aa7dc1663f 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -1,7 +1,7 @@ from typing import List, Dict, DefaultDict, Tuple, Union from collections import defaultdict from tinygrad.dtype import DType -from tinygrad.buffer import 
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from tinygrad.helpers import getenv, DEBUG, dedup
 from tinygrad.ops import ScheduleItem
 
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 5688891c48..e22b077eb4 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast, LazyOp
 from tinygrad.device import Runner, Device
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer
 
 # **************** Runners ****************
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 28100ec7f4..63fb0d93b6 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -6,7 +6,7 @@ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
 from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.buffer import Buffer
+from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary
 
 lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 0cb80538dc..3b1c8b06bd 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Union, Type, Tuple, Any, List, Dict, Callable
+from typing import Union, Type, Tuple, Any, List, Dict, Callable, TYPE_CHECKING
 import functools, hashlib, math, operator, ctypes
 from enum import Enum, auto
 from dataclasses import dataclass
@@ -7,7 +7,8 @@ from tinygrad.helpers import prod, dedup
 from tinygrad.dtype import dtypes, DType, ConstType
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.buffer import Buffer
+if TYPE_CHECKING:
+  from tinygrad.device import Buffer
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 6bf6f49790..73420cbce2 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -1,7 +1,7 @@
 import ctypes, collections, array, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
 from tinygrad.helpers import GraphException, round_up, to_mv
-from tinygrad.buffer import Buffer, BufferOptions
+from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Compiled, CompiledRunner, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer
diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py
index 5da2992ba4..d96076eabb 100644
--- a/tinygrad/runtime/graph/hsa.py
+++ b/tinygrad/runtime/graph/hsa.py
@@ -1,7 +1,7 @@
 import ctypes, collections, time, itertools
 from typing import List, Any, Dict, cast, Optional, Tuple
 from tinygrad.helpers import GraphException, init_c_var, round_up
-from tinygrad.buffer import Buffer, BufferOptions
+from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Compiled, CompiledRunner, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index 7e29853dc9..5fe684fa93 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 from typing import Tuple, List, Any, cast
 import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import BufferOptions, LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index a25f5ab7ee..8dead3d4fc 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -1,6 +1,5 @@
 import ctypes, subprocess, pathlib, tempfile
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import MallocAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
 from tinygrad.helpers import cpu_time_execution
 from tinygrad.renderer.cstyle import ClangRenderer
 
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index c48f58b752..61d983be53 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -5,8 +5,7 @@ from dataclasses import replace
 from typing import Tuple, Optional, List
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import BufferOptions, LRUAllocator, MallocAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator, MallocAllocator
 from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.renderer.assembly import PTXRenderer
 if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index 7493fd4996..99899e8eea 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -2,8 +2,7 @@ from __future__ import annotations
 import os, mmap, _posixshmem, io
 from typing import Optional
 from tinygrad.helpers import OSX
-from tinygrad.device import Compiled
-from tinygrad.buffer import Allocator
+from tinygrad.device import Compiled, Allocator
 
 class DiskBuffer:
   def __init__(self, device:DiskDevice, size:int, offset=0):
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 534bfe9938..ffa623e244 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -4,8 +4,7 @@ import ctypes, functools, hashlib
 import tinygrad.runtime.autogen.opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
 from tinygrad.renderer.cstyle import OpenCLRenderer
-from tinygrad.buffer import BufferOptions, LRUAllocator
-from tinygrad.device import Compiled, Compiler, CompilerOptions
+from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompilerOptions
 
 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0
diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py
index 1425fec92e..eab99f4b50 100644
--- a/tinygrad/runtime/ops_hsa.py
+++ b/tinygrad/runtime/ops_hsa.py
@@ -3,8 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
 from typing import Tuple, TypeVar, List, Dict, Any
 import tinygrad.runtime.autogen.hsa as hsa
 from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import BufferOptions, LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
 from tinygrad.runtime.driver.hip_comgr import compile_hip
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index df2d69b856..2eb86b3e13 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 import ctypes, functools
 from typing import Tuple
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import MallocAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
 from tinygrad.helpers import DEBUG, cpu_time_execution
 from tinygrad.renderer.llvmir import uops_to_llvm_ir
 import llvmlite.binding as llvm
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 13a588817e..c74639aa70 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -3,8 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile, functools
 import Metal, libdispatch
 from typing import List, Set, Any, Tuple, Optional
 from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator
 from tinygrad.renderer.cstyle import MetalRenderer
 
 def wait_check(cbuf: Any):
diff --git a/tinygrad/runtime/ops_npy.py b/tinygrad/runtime/ops_npy.py
index 3470d0edd2..c8121b9a09 100644
--- a/tinygrad/runtime/ops_npy.py
+++ b/tinygrad/runtime/ops_npy.py
@@ -1,7 +1,6 @@
 import numpy as np
 from tinygrad.helpers import flat_mv
-from tinygrad.device import Compiled
-from tinygrad.buffer import Allocator
+from tinygrad.device import Compiled, Allocator
 
 class NpyAllocator(Allocator):
   def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index ee041feb31..ad01c4ee87 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -2,8 +2,7 @@ from __future__ import annotations
 import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
 from typing import Tuple, List, Any, cast
 from dataclasses import replace
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import LRUAllocator, BufferOptions
+from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator, BufferOptions
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
 from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index e94fd22c5a..6ca539b72b 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -5,8 +5,7 @@ from typing import Tuple, List, Optional, Any, Dict
 import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
-from tinygrad.device import Compiled, Compiler, CompilerOptions
-from tinygrad.buffer import Allocator
+from tinygrad.device import Compiled, Compiler, CompilerOptions, Allocator
 from tinygrad.codegen.uops import UOpGraph, UOps
 from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 
diff --git a/tinygrad/runtime/ops_rhip.py b/tinygrad/runtime/ops_rhip.py
index 074509344f..7bf252ba59 100644
--- a/tinygrad/runtime/ops_rhip.py
+++ b/tinygrad/runtime/ops_rhip.py
@@ -1,6 +1,5 @@
 import ctypes
-from tinygrad.device import Compiled
-from tinygrad.buffer import MallocAllocator
+from tinygrad.device import Compiled, MallocAllocator
 from tinygrad.runtime.ops_hsa import HSACompiler
 
 rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b64792e918..9b61952309 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -12,7 +12,7 @@ from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps, ScheduleItem
-from tinygrad.buffer import Buffer, BufferOptions
+from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Device
 from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
 from tinygrad.engine.realize import run_schedule
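
For reviewers, a minimal usage sketch of the consolidated import surface after this change. It is not part of the patch and only exercises symbols the diff moves into tinygrad.device; the CLANG device string is just an example backend.

# sketch: Buffer (and BufferOptions, Allocator, LRUAllocator, MallocAllocator) now live in tinygrad.device
from tinygrad import dtypes
from tinygrad.device import Buffer  # formerly: from tinygrad.buffer import Buffer

# allocate a tiny buffer on the CLANG (CPU) backend and round-trip four bytes through it
buf = Buffer("CLANG", 4, dtypes.uint8, initial_value=bytes([1, 2, 3, 4]))
out = bytearray(buf.nbytes)
buf.copyout(memoryview(out))
assert bytes(out) == bytes([1, 2, 3, 4])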