bring buffer back to device (#4517)

This commit is contained in:
George Hotz
2024-05-10 11:22:31 -07:00
committed by GitHub
parent a2b707a3eb
commit d438d5698d
26 changed files with 157 additions and 174 deletions

View File

@@ -2,7 +2,7 @@
import gc
from tinygrad.helpers import prod
from tinygrad.lazy import LazyBuffer
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from tinygrad import Tensor, GlobalCounters
def print_objects():

View File

@@ -14,9 +14,8 @@ import onnx
from typing import Tuple, List, Optional, Dict, cast
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.buffer import Buffer
from tinygrad.dtype import ImageDType
from tinygrad.device import CompiledRunner
from tinygrad.device import CompiledRunner, Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem
from tinygrad.engine.memory import memory_planner

View File

@@ -1,7 +1,7 @@
import unittest, ctypes, struct, time, array
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import to_mv
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferOptions
from tinygrad.engine.schedule import create_schedule
def _time_queue(q, d):

View File

@@ -1,7 +1,7 @@
import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.lazy import LazyBuffer

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from tinygrad.lazy import view_supported_devices
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")

View File

@@ -1,132 +0,0 @@
from __future__ import annotations
from typing import Any, Optional, Dict, Tuple
import ctypes
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.helpers import GlobalCounters, flat_mv, from_mv, getenv
from tinygrad.dtype import DType, ImageDType
@dataclass(frozen=True, eq=True)
class BufferOptions:
image: Optional[ImageDType] = None
uncached: bool = False
cpu_access: bool = False
host: bool = False
nolru: bool = False
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
if base is None:
assert offset == 0, "base buffers can't have offset"
self._base = None
self._lb_refcount = lb_refcount
if opaque is not None: self.allocate(opaque)
if initial_value is not None:
self.allocate()
self.copyin(memoryview(initial_value))
else:
assert base._base is None, "base can't have a base"
assert device == base.device, "base must have the same device"
self._base = base
if preallocate: self.allocate()
@property
def base(self) -> Buffer: return self._base if self._base is not None else self
@property
def lb_refcount(self): return self.base._lb_refcount
def ref(self, cnt): self.base._lb_refcount += cnt
def is_allocated(self) -> bool: return hasattr(self, '_buf')
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
def allocate(self, opaque=None) -> Buffer:
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
from tinygrad.device import Device
self.allocator = Device[self.device].allocator
if self._base is not None:
self._base.ensure_allocated()
assert hasattr(self.allocator, "offset"), "offset function required for view"
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
else:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
return self
def __reduce__(self):
buf = None
if self._base is not None:
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
if self.is_allocated():
buf = bytearray(self.nbytes)
self.copyout(memoryview(buf))
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
@property
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
if not hasattr(self, '_buf'): return
if self._base is None:
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
def __repr__(self):
return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
(f" offset:{self.offset}" if hasattr(self, "base") else "") + \
(">" if self.options is None else f" {self.options=}>")
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyin to unallocated buffer"
self.allocator.copyin(self._buf, mv)
return self
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyout unallocated buffer"
self.allocator.copyout(mv, self._buf)
return mv
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
assert offset < self.nbytes, "offset must be less than nbytes"
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
return Buffer(self.device, size, dtype, base=self, offset=offset)
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
def alloc(self, size:int, options:Optional[BufferOptions]=None):
assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
return self._alloc(size, options if options is not None else BufferOptions())
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
self._free(opaque, options if options is not None else BufferOptions())
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
class LRUAllocator(Allocator): # pylint: disable=abstract-method
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
def alloc(self, size:int, options:Optional[BufferOptions]=None):
if len(c := self.cache[(size, options)]): return c.pop()
try: return super().alloc(size, options)
except (RuntimeError, MemoryError):
self.free_cache()
return super().alloc(size, options)
def free_cache(self):
for (sz,options),opaques in self.cache.items():
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator()

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import multiprocessing
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable
import importlib, inspect, functools, pathlib, os
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable, Any
import importlib, inspect, functools, pathlib, os, ctypes
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.dtype import DType, ImageDType
from tinygrad.ops import LazyOp, get_lazyop_info
from tinygrad.buffer import Buffer, Allocator
from tinygrad.codegen.uops import UOpGraph
if TYPE_CHECKING:
@@ -40,6 +41,132 @@ class _Device:
raise RuntimeError("no usable devices")
Device = _Device()
# **************** Buffer + Allocators ****************
@dataclass(frozen=True, eq=True)
class BufferOptions:
image: Optional[ImageDType] = None
uncached: bool = False
cpu_access: bool = False
host: bool = False
nolru: bool = False
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
if base is None:
assert offset == 0, "base buffers can't have offset"
self._base = None
self._lb_refcount = lb_refcount
if opaque is not None: self.allocate(opaque)
if initial_value is not None:
self.allocate()
self.copyin(memoryview(initial_value))
else:
assert base._base is None, "base can't have a base"
assert device == base.device, "base must have the same device"
self._base = base
if preallocate: self.allocate()
@property
def base(self) -> Buffer: return self._base if self._base is not None else self
@property
def lb_refcount(self): return self.base._lb_refcount
def ref(self, cnt): self.base._lb_refcount += cnt
def is_allocated(self) -> bool: return hasattr(self, '_buf')
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
def allocate(self, opaque=None) -> Buffer:
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
self.allocator = Device[self.device].allocator
if self._base is not None:
self._base.ensure_allocated()
assert hasattr(self.allocator, "offset"), "offset function required for view"
self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
else:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
return self
def __reduce__(self):
buf = None
if self._base is not None:
return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
if self.is_allocated():
buf = bytearray(self.nbytes)
self.copyout(memoryview(buf))
return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
@property
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
if not hasattr(self, '_buf'): return
if self._base is None:
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
def __repr__(self):
return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
(f" offset:{self.offset}" if hasattr(self, "base") else "") + \
(">" if self.options is None else f" {self.options=}>")
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyin to unallocated buffer"
self.allocator.copyin(self._buf, mv)
return self
def copyout(self, mv:memoryview) -> memoryview:
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
assert self.is_allocated(), "can't copyout unallocated buffer"
self.allocator.copyout(mv, self._buf)
return mv
def view(self, size:int, dtype:DType, offset:int) -> Buffer:
assert offset < self.nbytes, "offset must be less than nbytes"
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
return Buffer(self.device, size, dtype, base=self, offset=offset)
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
def alloc(self, size:int, options:Optional[BufferOptions]=None):
assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
return self._alloc(size, options if options is not None else BufferOptions())
def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
self._free(opaque, options if options is not None else BufferOptions())
def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
class LRUAllocator(Allocator): # pylint: disable=abstract-method
def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
def alloc(self, size:int, options:Optional[BufferOptions]=None):
if len(c := self.cache[(size, options)]): return c.pop()
try: return super().alloc(size, options)
except (RuntimeError, MemoryError):
self.free_cache()
return super().alloc(size, options)
def free_cache(self):
for (sz,options),opaques in self.cache.items():
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator()
# **************** base Runner + helpers ****************
class Runner:

View File

@@ -1,7 +1,7 @@
from typing import List, Dict, DefaultDict, Tuple, Union
from collections import defaultdict
from tinygrad.dtype import DType
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, DEBUG, dedup
from tinygrad.ops import ScheduleItem

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast, LazyOp
from tinygrad.device import Runner, Device
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from tinygrad.shape.symbolic import Variable, sym_infer
# **************** Runners ****************

View File

@@ -6,7 +6,7 @@ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.buffer import Buffer
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Union, Type, Tuple, Any, List, Dict, Callable
from typing import Union, Type, Tuple, Any, List, Dict, Callable, TYPE_CHECKING
import functools, hashlib, math, operator, ctypes
from enum import Enum, auto
from dataclasses import dataclass
@@ -7,7 +7,8 @@ from tinygrad.helpers import prod, dedup
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.buffer import Buffer
if TYPE_CHECKING:
from tinygrad.device import Buffer
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly

View File

@@ -1,7 +1,7 @@
import ctypes, collections, array, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import GraphException, round_up, to_mv
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer

View File

@@ -1,7 +1,7 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Tuple
from tinygrad.helpers import GraphException, init_c_var, round_up
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
from typing import Tuple, List, Any, cast
import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip

View File

@@ -1,6 +1,5 @@
import ctypes, subprocess, pathlib, tempfile
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import MallocAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
from tinygrad.helpers import cpu_time_execution
from tinygrad.renderer.cstyle import ClangRenderer

View File

@@ -5,8 +5,7 @@ from dataclasses import replace
from typing import Tuple, Optional, List
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import BufferOptions, LRUAllocator, MallocAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator, MallocAllocator
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.assembly import PTXRenderer
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401

View File

@@ -2,8 +2,7 @@ from __future__ import annotations
import os, mmap, _posixshmem, io
from typing import Optional
from tinygrad.helpers import OSX
from tinygrad.device import Compiled
from tinygrad.buffer import Allocator
from tinygrad.device import Compiled, Allocator
class DiskBuffer:
def __init__(self, device:DiskDevice, size:int, offset=0):

View File

@@ -4,8 +4,7 @@ import ctypes, functools, hashlib
import tinygrad.runtime.autogen.opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.buffer import BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompilerOptions
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
OSX_TIMING_RATIO = (125/3) if OSX else 1.0

View File

@@ -3,8 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
from tinygrad.runtime.driver.hip_comgr import compile_hip

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
import ctypes, functools
from typing import Tuple
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import MallocAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
from tinygrad.helpers import DEBUG, cpu_time_execution
from tinygrad.renderer.llvmir import uops_to_llvm_ir
import llvmlite.binding as llvm

View File

@@ -3,8 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Set, Any, Tuple, Optional
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import LRUAllocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
def wait_check(cbuf: Any):

View File

@@ -1,7 +1,6 @@
import numpy as np
from tinygrad.helpers import flat_mv
from tinygrad.device import Compiled
from tinygrad.buffer import Allocator
from tinygrad.device import Compiled, Allocator
class NpyAllocator(Allocator):
def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)

View File

@@ -2,8 +2,7 @@ from __future__ import annotations
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
from typing import Tuple, List, Any, cast
from dataclasses import replace
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import LRUAllocator, BufferOptions
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator, BufferOptions
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes

View File

@@ -5,8 +5,7 @@ from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, CompilerOptions
from tinygrad.buffer import Allocator
from tinygrad.device import Compiled, Compiler, CompilerOptions, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu

View File

@@ -1,6 +1,5 @@
import ctypes
from tinygrad.device import Compiled
from tinygrad.buffer import MallocAllocator
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.runtime.ops_hsa import HSACompiler
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")

View File

@@ -12,7 +12,7 @@ from tinygrad.helpers import getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Buffer, BufferOptions
from tinygrad.device import Device
from tinygrad.shape.symbolic import sint, Variable, MulNode, Node
from tinygrad.engine.realize import run_schedule