Devicebufferless (#708)
* runs one metal kernel
* conv2d works
* ops tests are passing
* const folding
* all ops work
* pre commit always passes
* torch works
* working still
* fix graph test
* tests passing
* image almost works
* image conv works
* most images
* fix custom
* fix assignment
* fix compile enet
* clean up comments
* fix realize return value
* include shapetracker in LB repr
* copy should make a copy
* reenable method cache
* fix lna
* dtypes in graph
* forward only for IMAGE=2
* simple realize
* getting close
* fixup new api, it's good except the kernel count
* back to 197 kernels
* tests should pass
* go to a real float
* no type_on_cpu
* fix the docs
* put shapetracker back in it's proper place
@@ -77,14 +77,20 @@ class LazyBuffer:
shape: Tuple[int, ...]
dtype: DType

# a ShapeTracker is used to track things like reshapes and permutes
# all MovementOps are zero copy in tinygrad!
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker

# if the LazyBuffer is realized, it has a RawBuffer
# we will come back to RawBuffers later
realized: Optional[RawBuffer]

# if the lazybuffer is unrealized, it has a LazyOp
# this LazyOp describes the computation needed to realize this LazyBuffer
op: Optional[LazyOp]

# if the LazyBuffer is realized, it has a DeviceBuffer
# we will come back to DeviceBuffers later, first we'll explore the LazyOp
realized: Optional[DeviceBuffer]
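
# a standalone analogy for the zero-copy idea: movement ops only change how indices
# map into the same underlying memory, never the memory itself. numpy's stride
# tricks stand in for ShapeTracker here; this is not tinygrad code
import numpy as np
buf = np.arange(6, dtype=np.float32)                                        # the flat "RawBuffer"
view = np.lib.stride_tricks.as_strided(buf, shape=(2, 3), strides=(12, 4))  # a "reshape" to (2, 3)
permuted = view.transpose(1, 0)                                             # a "permute" to (3, 2), still no copy
assert np.shares_memory(buf, permuted)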

# LazyOp (in tinygrad/ops.py, code 4/10)
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
class LazyOp:
@@ -128,81 +134,60 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
print(lazyop.src[0].op)
assert lazyop.src[0].op.op == LoadOps.FROMCPU
assert lazyop.src[0].op.arg[0] == [2], "the arg of the FROMCPU LazyOp is [2.]"
assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOp is [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"

# now we realize the LazyBuffer
result.lazydata.realize()
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
assert 'ClangBuffer' in str(type(result.lazydata.realized))
assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
# getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
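
# a minimal sketch of the same laziness seen from the public Tensor API, assuming the
# Tensor class used earlier in this file; numpy() forces realization
from tinygrad.tensor import Tensor
a, b = Tensor([2.0]), Tensor([3.0])
out = a + b
assert out.lazydata.realized is None   # nothing has been computed yet
assert out.numpy()[0] == 5.0           # realizes the LazyBuffer, then copies to CPU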

# %%
# == DeviceBuffer (in tinygrad/ops.py, code 4/10) ==
# == Union[Interpreted, Compiled] (in tinygrad/ops.py, code 5/10) ==

# DeviceBuffer is an abstract class to be implemented for each Device backend
class DeviceBuffer(ABC):
# these two are straightforward.
# unlike LazyBuffer, there's no need for device, since that's contained in the concrete type
shape: Tuple[int, ...]
dtype: DType
# Now you have a choice: you can either write an "Interpreted" backend or a "Compiled" backend

# this is the magic method that "fills" a DeviceBuffer and does all the math in tinygrad
# NOTE: fromCPU no longer exists here, it's now just a LoadOps AST node, LoadOps.FROMCPU
def exec_ast(self, ast:LazyOp): raise NotImplementedError("must be implemented")
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# they have a backing RawBuffer
buffer: Type[RawBuffer]

# however, toCPU still exists. it will raise a RuntimeException if exec_ast has never been called
# it copies out the underlying buffer to the CPU, and will do any sync operations
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP: lambda x: np.exp(x),
BinaryOps.ADD: lambda x,y: x+y}

# DeviceBuffers come in two flavors, InterpretedBuffer and CompiledBuffer
# InterpretedBuffers are a lot simpler than CompiledBuffers
# they are used to implement the CPU(numpy) and TORCH(torch) backends
# it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
import numpy as np
import torch
class InterpretedBuffer(DeviceBuffer):
# this is where the data actually lives
# finally some classes you recognize!
_buf: Union[np.ndarray, torch.Tensor]
# Compiled backends take a little more work (example: GPU and LLVM)
class Compiled:
# they also have a backing RawBuffer
buffer: Type[RawBuffer]

# the compute itself is defined here. these functions are called with _buf
# here's a UnaryOp and BinaryOp from CPUBuffer(InterpretedBuffer)
fxn_for_op: ClassVar[Dict[Op, Callable]] = {UnaryOps.EXP: lambda x: np.exp(x), BinaryOps.ADD: lambda x,y: x+y}

# NOTE: exec_ast should not need to be overridden!
# The actual method lives in tinygrad/ops.py
# it walks the LazyOp tree and calls fxn_for_op as appropriate
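
# a standalone sketch of that walk, using a ToyOp stand-in instead of the real LazyOp:
# evaluate the sources recursively, then dispatch on the Op through fxn_for_op
# (the real exec_ast also handles MovementOps, op.arg, and output buffer reuse)
from typing import NamedTuple, Tuple, Any
import numpy as np

class ToyOp(NamedTuple):
  op: Any
  src: Tuple[Any, ...]

def toy_exec_ast(ast:ToyOp, fxn_for_op):
  srcs = [toy_exec_ast(s, fxn_for_op) if isinstance(s, ToyOp) else s for s in ast.src]
  return fxn_for_op[ast.op](*srcs)

# exp(2) + 3, evaluated with the same style of lookup table shown above
toy_table = {"EXP": lambda x: np.exp(x), "ADD": lambda x,y: x+y}
print(toy_exec_ast(ToyOp("ADD", (ToyOp("EXP", (np.array([2.0]),)), np.array([3.0]))), toy_table))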

# ********** NOTE: for the CPU and TORCH backends, we are done and you can stop reading here **********

# %%
# == CompiledBuffer (in tinygrad/ops.py, code 4/10) ==

# however, all the magic of tinygrad will come from CompiledBuffer
# this is used for the GPU(opencl), CUDA, METAL, CLANG, and LLVM backends
class CompiledBuffer(DeviceBuffer):
# this is where the data actually lives, same as InterpretedBuffer
# a RawBuffer is just raw (typed) memory on the Device in question
_buf: RawBuffer

# introducing...ShapeTracker! all MovementOps are zero copy in tinygrad
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
# we'll come back to this later
st: ShapeTracker

# NOTE: exec_ast should not need to be overridden!
# instead you need three classes, explained below
raw_buffer: Type[RawBuffer]
runtime: Type[Runtime]
# a code generator, which compiles the AST
codegen: Type[ASTKernel]

# for completeness, we include RawBuffer. it's very boring and exactly what you expect
# and a runtime, which runs the generated code
runtime: Type[Runtime]

# Runtime is what actually runs the kernels for a compiled backend
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
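
# a minimal sketch of a clang-style Runtime, assuming a `clang` binary on PATH:
# the constructor compiles `prg` to a shared object, __call__ runs it on raw memory
# (ToyClangRuntime is illustrative only; error handling, caching, and launch dims are omitted)
import ctypes, pathlib, subprocess, tempfile

class ToyClangRuntime:
  def __init__(self, name:str, prg:str):
    fn = pathlib.Path(tempfile.mkdtemp()) / f"{name}.so"
    subprocess.check_output(["clang", "-shared", "-fPIC", "-O2", "-x", "c", "-", "-o", str(fn)], input=prg.encode())
    self.fxn = ctypes.CDLL(str(fn))[name]
  def __call__(self, *bufs): self.fxn(*bufs)

# usage: 2+3, with the output written into bufs[0] by convention
prg = "void add(float *out, const float *a, const float *b) { out[0] = a[0] + b[0]; }"
out, a, b = (ctypes.c_float * 1)(), (ctypes.c_float * 1)(2.0), (ctypes.c_float * 1)(3.0)
ToyClangRuntime("add", prg)(out, a, b)
assert out[0] == 5.0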

# %%
# == RawBuffer (in tinygrad/runtime/lib.py, code 5/10) ==
import numpy as np

# RawBuffer is where the data is actually held. it's pretty close to just memory
class RawBuffer(ABC):
# create an empty rawbuffer that holds `size` elements of type `dtype`
def __init__(self, size:int, dtype:DType): raise NotImplementedError("must be implemented")
# `buf` is an opaque container class
def __init__(self, size:int, dtype:DType, buf:Any): raise NotImplementedError("must be implemented")

# fromCPU is a classmethod that creates a RawBuffer; it's a classmethod since some runtimes are zero copy
@classmethod
@@ -211,13 +196,14 @@ class RawBuffer(ABC):
# toCPU converts the RawBuffer to a numpy array with shape (size,). many backends are 0 copy here
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")

# Runtime is what actually runs the kernels
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
# RawNumpyBuffer is a RawBuffer example for numpy. It's very simple
class RawNumpyBuffer(RawBuffer):
# NOTE: the "np.ndarray" is stored in the opaque container
def __init__(self, buf:np.ndarray):
super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
@classmethod
def fromCPU(cls, x): return cls(x)
def toCPU(self): return self._buf
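
# a quick usage sketch of the class above, assuming the real RawBuffer base stores
# size, dtype, and _buf (the abstract sketch above elides that storage):
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
rb = RawNumpyBuffer.fromCPU(x)
assert rb.toCPU() is x   # zero copy: the same ndarray comes back out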

# %%
# == Example: 2+3 in raw clang ==
@@ -262,11 +248,11 @@ class ASTKernel:
def __init__(self, ast:LazyOp): pass
def codegen(self) -> ASTRunner: pass

# we return a class that runs code on CompiledBuffers
# we return a class that runs code on LazyBuffers, which are all expected to be realized
class ASTRunner: # (from tinygrad/ops.py)
def __init__(self, name, prg, global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
def build(self, runtime:Runtime): pass
def exec(self, bufs:List[CompiledBuffer]): pass
def exec(self, bufs:List[LazyBuffer]): pass

# that hides a lot of complexity that will be refactored, but that's the basic idea of code generation
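
# a minimal sketch of how codegen, build, and exec compose for a Compiled backend,
# using the interfaces sketched above (names simplified; the real exec_ast in
# tinygrad/ops.py also handles buffer allocation and kernel caching):
def toy_compiled_exec_ast(ast, output, bufs, codegen, runtime):
  prg = codegen(ast).codegen()     # ASTKernel -> ASTRunner (name, code, launch dims)
  prg.build(runtime)               # compile the generated code for this Device's Runtime
  prg.exec([output] + bufs)        # run it; by convention the output is bufs[0]
  return output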