Mirror of https://github.com/tinygrad/tinygrad.git
Refactor triton buffer to use CLBuffer of cuda runtime (#524)
* Refactor triton buffer to use CLBuffer of runtime
* Fix opencl GT0
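In short, the Triton backend stops hand-rolling pycuda memory management (the old TritonWrapper plus raw cuda.mem_alloc / cuda.memcpy_* calls) and instead builds its device allocations on the CLBuffer class from tinygrad.runtime.cuda; the TritonBuffer.torch property becomes .cuda, and the runtime switches from pycuda.autoinit to pycuda.autoprimaryctx, which attaches to the device's primary context rather than creating a fresh one, presumably so Triton kernels and pycuda allocations end up in the same context. A condensed sketch of the resulting allocation path, paraphrased from the diff below (the abridged TritonBuffer constructor and anything else not visible in the diff is an assumption):

```python
# Condensed sketch of the new allocation path; paraphrased from the diff below,
# not the complete classes (TritonBuffer's real constructor takes a ShapeTracker
# and a hostbuf -- the tiny __init__ here is only for illustration).
import pycuda.driver as cuda                   # type: ignore
from torch import float32
from tinygrad.helpers import prod
from tinygrad.runtime.cuda import CLBuffer     # importing the runtime also pulls in pycuda.autoprimaryctx

stream = cuda.Stream()

class TritonDeviceAllocation(CLBuffer):
  def __init__(self, size):
    super().__init__(size)                     # CLBuffer: self.cl = cuda.mem_alloc(size)
    self.dtype = float32
  def data_ptr(self): return int(self.cl)

class TritonBuffer:                            # heavily abridged
  def __init__(self, base_shape, backing=None):
    self._buf, self._base_shape, self._backing = None, base_shape, backing
  @property
  def cuda(self):
    if self._buf is None:                      # allocate lazily, on first use
      self._buf = TritonDeviceAllocation(4*prod(self._base_shape))
      if self._backing is not None: self._buf.copyin(self._backing, stream)   # async host->device upload
    return self._buf
```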
@@ -3,7 +3,6 @@ import hashlib
from weakref import WeakValueDictionary
from torch import float32
import numpy as np
# import pycuda.autoinit # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.autoprimaryctx # type: ignore # noqa: F401
import pycuda.driver as cuda # type: ignore

@@ -14,6 +13,7 @@ from typing import Union, Tuple, Optional, Dict
from tinygrad.ops import MovementOps, UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, DEBUG, GlobalCounters
from tinygrad.shape import ShapeTracker
from tinygrad.helpers import prod
+from tinygrad.runtime.cuda import CLBuffer
from tinygrad.ast import ASTKernel

stream = cuda.Stream()
@@ -99,23 +99,23 @@ class TritonASTKernel(ASTKernel):
    def runner(*bufs):
      GlobalCounters.global_ops += self.info.flops
      GlobalCounters.global_mem += mem_estimate
-     return program[tuple(self.output_shape[::-1])](*[TritonWrapper(x.torch) for x in bufs], stream=stream.handle)
+     return program[tuple(self.output_shape[::-1])](*[x.cuda for x in bufs], stream=stream.handle)
    self.func_cache[self.key] = runner
    return runner

class TritonBuffer(ExplicitExecAST):
  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[TritonBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
    super().__init__(shape, hostbuf)
-   self._buf : Optional[TritonWrapper] = hostbuf._buf if hostbuf is not None else None
+   self._buf : Optional[TritonDeviceAllocation] = hostbuf._buf if hostbuf is not None else None
    self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
    self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
-   if force_create: self.torch
+   if force_create: self.cuda

  @property
- def torch(self):
+ def cuda(self):
    if self._buf is None:
-     self._buf = cuda.mem_alloc(4*prod(self._base_shape))
-     if self._backing is not None: cuda.memcpy_htod_async(self._buf, self._backing, stream=stream)
+     self._buf = TritonDeviceAllocation(4*prod(self._base_shape))
+     if self._backing is not None: self._buf.copyin(self._backing, stream)
    return self._buf

  @staticmethod
@@ -124,8 +124,7 @@ class TritonBuffer(ExplicitExecAST):
  def toCPU(self):
    data = np.empty(self.shape, dtype=np.float32)
    buf = self.contiguous() if self._buf is not None else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
-   # TODO should this be sync?
-   cuda.memcpy_dtoh_async(data, buf._buf, stream=stream)
+   buf._buf.copyout(data)
    return data

  @classmethod
@@ -134,10 +133,9 @@ class TritonBuffer(ExplicitExecAST):
    k.codegen()(*k.bufs)
    return k.ret

-class TritonWrapper:
-  def __init__(self, ptr):
-    self.ptr = ptr
+class TritonDeviceAllocation(CLBuffer):
+  def __init__(self, size):
+    super().__init__(size)
+    self.dtype = float32

-  def data_ptr(self):
-    return int(self.ptr)
+  def data_ptr(self): return int(self.cl)
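The data_ptr() / dtype pair on the new class is what lets compiled kernels consume these allocations directly: Triton duck-types its tensor arguments, so (under that assumption about its launcher) any object exposing a raw device pointer and a dtype can stand in for a torch tensor, which is why the per-launch TritonWrapper around x.torch is no longer needed. A standalone illustration with a hypothetical class name:

```python
# Standalone sketch with a hypothetical name (DeviceTensorLike); mirrors what
# TritonDeviceAllocation provides. Assumes Triton only reads .data_ptr() and
# .dtype off buffer arguments, and that a CUDA-capable device is present.
import pycuda.autoprimaryctx  # type: ignore # noqa: F401  # attach to the primary context
import pycuda.driver as cuda  # type: ignore
from torch import float32

class DeviceTensorLike:
  def __init__(self, nbytes):
    self.cl, self.dtype = cuda.mem_alloc(nbytes), float32
  def data_ptr(self): return int(self.cl)     # raw device pointer as a plain int

t = DeviceTensorLike(4*1024)
assert t.data_ptr() == int(t.cl)              # pycuda's DeviceAllocation casts to its address
```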
@@ -29,7 +29,8 @@ def split_float4(x):

class CLASTKernel(ASTKernel):
  code_for_op : Dict[Op, str] = {
-   UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.GT0: "((float)1.-step((float)0.,(-A)))",
+   UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
+   UnaryOps.GT0: "(A > 0.)" if CUDA else "((float)1.-step((float)0.,(-A)))",
    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
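The GT0 rewrite is needed because step() is an OpenCL built-in with no CUDA C counterpart: step(edge, x) returns 0.0 when x < edge and 1.0 otherwise, so (float)1.-step((float)0.,(-A)) evaluates to 1.0 exactly when A > 0, while the CUDA path emits a plain comparison that promotes to 0/1. A quick numpy check (illustrative only) that the two formulations agree:

```python
# Compare the OpenCL-style formulation 1 - step(0, -A) with the plain
# comparison (A > 0) emitted when generating CUDA code.
import numpy as np

def step(edge, x): return (x >= edge).astype(np.float32)   # OpenCL step() semantics

A = np.array([-2.0, -0.0, 0.0, 1e-7, 3.5], dtype=np.float32)
opencl_gt0 = 1.0 - step(0.0, -A)
cuda_gt0 = (A > 0.0).astype(np.float32)
assert np.array_equal(opencl_gt0, cuda_gt0)    # both are 1.0 only for strictly positive A
```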
@@ -1,4 +1,5 @@
-import pycuda.autoinit # type: ignore # pylint: disable=unused-import # noqa: F401
+from typing import Optional
+import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import SourceModule # type: ignore
import numpy as np
@@ -9,7 +10,7 @@ class CLImage:

class CLBuffer:
  def __init__(self, size): self.cl = cuda.mem_alloc(size)
- def copyin(self, b:np.ndarray): cuda.memcpy_htod_async(self.cl, b)
+ def copyin(self, b:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self.cl, b, stream)
  def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self.cl)

class CLProgram:
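On the runtime side, copyin gains an optional stream so the Triton backend can queue host-to-device uploads on its own stream, while copyout stays a plain blocking memcpy_dtoh, which is presumably why the old "# TODO should this be sync?" line could be dropped in toCPU above. A minimal round trip through the updated interface (a sketch; truly asynchronous uploads would additionally need page-locked host memory):

```python
# Round-trip sketch against the updated CLBuffer interface from this diff.
# Assumes a CUDA-capable device; the context comes from pycuda.autoprimaryctx,
# which the runtime imports.
import numpy as np
import pycuda.driver as cuda  # type: ignore
from tinygrad.runtime.cuda import CLBuffer

stream = cuda.Stream()
src = np.arange(16, dtype=np.float32)

buf = CLBuffer(src.nbytes)
buf.copyin(src, stream)        # upload queued on our stream (pageable host memory may stage synchronously)
out = np.empty_like(src)
buf.copyout(out)               # blocking device->host copy
assert np.array_equal(src, out)
```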