mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
bring buffer back to device (#4517)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import gc
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
|
||||
def print_objects():
|
||||
|
||||
@@ -14,9 +14,8 @@ import onnx
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from extra.onnx import get_run_onnx
|
||||
from tinygrad import Tensor, Device, GlobalCounters, dtypes
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.device import CompiledRunner
|
||||
from tinygrad.device import CompiledRunner, Buffer
|
||||
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
|
||||
2
test/external/external_test_hcq.py
vendored
2
test/external/external_test_hcq.py
vendored
@@ -1,7 +1,7 @@
|
||||
import unittest, ctypes, struct, time, array
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import to_mv
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
def _time_queue(q, d):
|
||||
|
||||
2
test/external/fuzz_schedule.py
vendored
2
test/external/fuzz_schedule.py
vendored
@@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
|
||||
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.lazy import view_supported_devices
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Optional, Dict, Tuple
|
||||
import ctypes
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import GlobalCounters, flat_mv, from_mv, getenv
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
|
||||
@dataclass(frozen=True, eq=True)
|
||||
class BufferOptions:
|
||||
image: Optional[ImageDType] = None
|
||||
uncached: bool = False
|
||||
cpu_access: bool = False
|
||||
host: bool = False
|
||||
nolru: bool = False
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
|
||||
assert isinstance(dtype, DType)
|
||||
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
||||
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
|
||||
if base is None:
|
||||
assert offset == 0, "base buffers can't have offset"
|
||||
self._base = None
|
||||
self._lb_refcount = lb_refcount
|
||||
if opaque is not None: self.allocate(opaque)
|
||||
if initial_value is not None:
|
||||
self.allocate()
|
||||
self.copyin(memoryview(initial_value))
|
||||
else:
|
||||
assert base._base is None, "base can't have a base"
|
||||
assert device == base.device, "base must have the same device"
|
||||
self._base = base
|
||||
if preallocate: self.allocate()
|
||||
@property
|
||||
def base(self) -> Buffer: return self._base if self._base is not None else self
|
||||
@property
|
||||
def lb_refcount(self): return self.base._lb_refcount
|
||||
def ref(self, cnt): self.base._lb_refcount += cnt
|
||||
def is_allocated(self) -> bool: return hasattr(self, '_buf')
|
||||
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
|
||||
def allocate(self, opaque=None) -> Buffer:
|
||||
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
|
||||
from tinygrad.device import Device
|
||||
self.allocator = Device[self.device].allocator
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "offset"), "offset function required for view"
|
||||
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
|
||||
else:
|
||||
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
|
||||
return self
|
||||
def __reduce__(self):
|
||||
buf = None
|
||||
if self._base is not None:
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
|
||||
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
|
||||
if self.is_allocated():
|
||||
buf = bytearray(self.nbytes)
|
||||
self.copyout(memoryview(buf))
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
|
||||
@property
|
||||
def nbytes(self): return self.size*self.dtype.itemsize
|
||||
def __del__(self):
|
||||
if not hasattr(self, '_buf'): return
|
||||
if self._base is None:
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
def __repr__(self):
|
||||
return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
|
||||
(f" offset:{self.offset}" if hasattr(self, "base") else "") + \
|
||||
(">" if self.options is None else f" {self.options=}>")
|
||||
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
||||
# zero copy with as_buffer (disabled by default due to use after free)
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
|
||||
assert not force_zero_copy, "force zero copy was passed, but copy is required"
|
||||
return self.copyout(memoryview(bytearray(self.nbytes)))
|
||||
def copyin(self, mv:memoryview):
|
||||
mv = flat_mv(mv)
|
||||
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
||||
assert self.is_allocated(), "can't copyin to unallocated buffer"
|
||||
self.allocator.copyin(self._buf, mv)
|
||||
return self
|
||||
def copyout(self, mv:memoryview) -> memoryview:
|
||||
mv = flat_mv(mv)
|
||||
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
||||
assert self.is_allocated(), "can't copyout unallocated buffer"
|
||||
self.allocator.copyout(mv, self._buf)
|
||||
return mv
|
||||
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
|
||||
assert offset < self.nbytes, "offset must be less than nbytes"
|
||||
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
|
||||
return Buffer(self.device, size, dtype, base=self, offset=offset)
|
||||
|
||||
# TODO: size, dest, src are the same type. can we enforce this?
|
||||
class Allocator:
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
|
||||
return self._alloc(size, options if options is not None else BufferOptions())
|
||||
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
|
||||
def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
|
||||
self._free(opaque, options if options is not None else BufferOptions())
|
||||
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
|
||||
def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
|
||||
def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
|
||||
|
||||
class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
||||
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
if len(c := self.cache[(size, options)]): return c.pop()
|
||||
try: return super().alloc(size, options)
|
||||
except (RuntimeError, MemoryError):
|
||||
self.free_cache()
|
||||
return super().alloc(size, options)
|
||||
def free_cache(self):
|
||||
for (sz,options),opaques in self.cache.items():
|
||||
for opaque in opaques: super().free(opaque, sz, options)
|
||||
opaques.clear()
|
||||
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
|
||||
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
|
||||
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
||||
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
||||
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
|
||||
def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
|
||||
|
||||
MallocAllocator = _MallocAllocator()
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable
|
||||
import importlib, inspect, functools, pathlib, os
|
||||
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable, Any
|
||||
import importlib, inspect, functools, pathlib, os, ctypes
|
||||
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.ops import LazyOp, get_lazyop_info
|
||||
from tinygrad.buffer import Buffer, Allocator
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,6 +41,132 @@ class _Device:
|
||||
raise RuntimeError("no usable devices")
|
||||
Device = _Device()
|
||||
|
||||
# **************** Buffer + Allocators ****************
|
||||
|
||||
@dataclass(frozen=True, eq=True)
|
||||
class BufferOptions:
|
||||
image: Optional[ImageDType] = None
|
||||
uncached: bool = False
|
||||
cpu_access: bool = False
|
||||
host: bool = False
|
||||
nolru: bool = False
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
|
||||
assert isinstance(dtype, DType)
|
||||
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
|
||||
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
|
||||
if base is None:
|
||||
assert offset == 0, "base buffers can't have offset"
|
||||
self._base = None
|
||||
self._lb_refcount = lb_refcount
|
||||
if opaque is not None: self.allocate(opaque)
|
||||
if initial_value is not None:
|
||||
self.allocate()
|
||||
self.copyin(memoryview(initial_value))
|
||||
else:
|
||||
assert base._base is None, "base can't have a base"
|
||||
assert device == base.device, "base must have the same device"
|
||||
self._base = base
|
||||
if preallocate: self.allocate()
|
||||
@property
|
||||
def base(self) -> Buffer: return self._base if self._base is not None else self
|
||||
@property
|
||||
def lb_refcount(self): return self.base._lb_refcount
|
||||
def ref(self, cnt): self.base._lb_refcount += cnt
|
||||
def is_allocated(self) -> bool: return hasattr(self, '_buf')
|
||||
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
|
||||
def allocate(self, opaque=None) -> Buffer:
|
||||
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
|
||||
self.allocator = Device[self.device].allocator
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "offset"), "offset function required for view"
|
||||
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
|
||||
else:
|
||||
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
|
||||
return self
|
||||
def __reduce__(self):
|
||||
buf = None
|
||||
if self._base is not None:
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
|
||||
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
|
||||
if self.is_allocated():
|
||||
buf = bytearray(self.nbytes)
|
||||
self.copyout(memoryview(buf))
|
||||
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
|
||||
@property
|
||||
def nbytes(self): return self.size*self.dtype.itemsize
|
||||
def __del__(self):
|
||||
if not hasattr(self, '_buf'): return
|
||||
if self._base is None:
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
def __repr__(self):
|
||||
return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
|
||||
(f" offset:{self.offset}" if hasattr(self, "base") else "") + \
|
||||
(">" if self.options is None else f" {self.options=}>")
|
||||
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
||||
# zero copy with as_buffer (disabled by default due to use after free)
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
|
||||
assert not force_zero_copy, "force zero copy was passed, but copy is required"
|
||||
return self.copyout(memoryview(bytearray(self.nbytes)))
|
||||
def copyin(self, mv:memoryview):
|
||||
mv = flat_mv(mv)
|
||||
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
||||
assert self.is_allocated(), "can't copyin to unallocated buffer"
|
||||
self.allocator.copyin(self._buf, mv)
|
||||
return self
|
||||
def copyout(self, mv:memoryview) -> memoryview:
|
||||
mv = flat_mv(mv)
|
||||
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
|
||||
assert self.is_allocated(), "can't copyout unallocated buffer"
|
||||
self.allocator.copyout(mv, self._buf)
|
||||
return mv
|
||||
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
|
||||
assert offset < self.nbytes, "offset must be less than nbytes"
|
||||
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
|
||||
return Buffer(self.device, size, dtype, base=self, offset=offset)
|
||||
|
||||
# TODO: size, dest, src are the same type. can we enforce this?
|
||||
class Allocator:
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
|
||||
return self._alloc(size, options if options is not None else BufferOptions())
|
||||
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
|
||||
def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
|
||||
self._free(opaque, options if options is not None else BufferOptions())
|
||||
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
|
||||
def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
|
||||
def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
|
||||
|
||||
class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
||||
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
|
||||
def alloc(self, size:int, options:Optional[BufferOptions]=None):
|
||||
if len(c := self.cache[(size, options)]): return c.pop()
|
||||
try: return super().alloc(size, options)
|
||||
except (RuntimeError, MemoryError):
|
||||
self.free_cache()
|
||||
return super().alloc(size, options)
|
||||
def free_cache(self):
|
||||
for (sz,options),opaques in self.cache.items():
|
||||
for opaque in opaques: super().free(opaque, sz, options)
|
||||
opaques.clear()
|
||||
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
|
||||
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
|
||||
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
||||
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
||||
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
|
||||
def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
|
||||
|
||||
MallocAllocator = _MallocAllocator()
|
||||
|
||||
# **************** base Runner + helpers ****************
|
||||
|
||||
class Runner:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Dict, DefaultDict, Tuple, Union
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv, DEBUG, dedup
|
||||
from tinygrad.ops import ScheduleItem
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen
|
||||
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast, LazyOp
|
||||
from tinygrad.device import Runner, Device
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
|
||||
# **************** Runners ****************
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Buffer
|
||||
from weakref import ref, ReferenceType, WeakValueDictionary
|
||||
|
||||
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Union, Type, Tuple, Any, List, Dict, Callable
|
||||
from typing import Union, Type, Tuple, Any, List, Dict, Callable, TYPE_CHECKING
|
||||
import functools, hashlib, math, operator, ctypes
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
@@ -7,7 +7,8 @@ from tinygrad.helpers import prod, dedup
|
||||
from tinygrad.dtype import dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.buffer import Buffer
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.device import Buffer
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ctypes, collections, array, time
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
||||
from tinygrad.helpers import GraphException, round_up, to_mv
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, CompiledRunner, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ctypes, collections, time, itertools
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple
|
||||
from tinygrad.helpers import GraphException, init_c_var, round_up
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, CompiledRunner, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, List, Any, cast
|
||||
import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
|
||||
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.driver.hip_comgr import compile_hip
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import ctypes, subprocess, pathlib, tempfile
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
|
||||
from tinygrad.helpers import cpu_time_execution
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from dataclasses import replace
|
||||
from typing import Tuple, Optional, List
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import BufferOptions, LRUAllocator, MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator, MallocAllocator
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.renderer.assembly import PTXRenderer
|
||||
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401
|
||||
|
||||
@@ -2,8 +2,7 @@ from __future__ import annotations
|
||||
import os, mmap, _posixshmem, io
|
||||
from typing import Optional
|
||||
from tinygrad.helpers import OSX
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.buffer import Allocator
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
|
||||
class DiskBuffer:
|
||||
def __init__(self, device:DiskDevice, size:int, offset=0):
|
||||
|
||||
@@ -4,8 +4,7 @@ import ctypes, functools, hashlib
|
||||
import tinygrad.runtime.autogen.opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
from tinygrad.buffer import BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompilerOptions
|
||||
|
||||
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0
|
||||
|
||||
@@ -3,8 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
|
||||
from typing import Tuple, TypeVar, List, Dict, Any
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
|
||||
from tinygrad.runtime.driver.hip_comgr import compile_hip
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, functools
|
||||
from typing import Tuple
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
|
||||
from tinygrad.helpers import DEBUG, cpu_time_execution
|
||||
from tinygrad.renderer.llvmir import uops_to_llvm_ir
|
||||
import llvmlite.binding as llvm
|
||||
|
||||
@@ -3,8 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile, functools
|
||||
import Metal, libdispatch
|
||||
from typing import List, Set, Any, Tuple, Optional
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
def wait_check(cbuf: Any):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
from tinygrad.helpers import flat_mv
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.buffer import Allocator
|
||||
from tinygrad.device import Compiled, Allocator
|
||||
|
||||
class NpyAllocator(Allocator):
|
||||
def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
|
||||
|
||||
@@ -2,8 +2,7 @@ from __future__ import annotations
|
||||
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
|
||||
from typing import Tuple, List, Any, cast
|
||||
from dataclasses import replace
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import LRUAllocator, BufferOptions
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator, BufferOptions
|
||||
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes
|
||||
|
||||
@@ -5,8 +5,7 @@ from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.buffer import Allocator
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, Allocator
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import ctypes
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.buffer import MallocAllocator
|
||||
from tinygrad.device import Compiled, MallocAllocator
|
||||
from tinygrad.runtime.ops_hsa import HSACompiler
|
||||
|
||||
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
|
||||
|
||||
@@ -12,7 +12,7 @@ from tinygrad.helpers import getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import LoadOps, ScheduleItem
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
Reference in New Issue
Block a user