Refactor TritonBuffer to use the cuda runtime's CLBuffer (#524)

* Refactor TritonBuffer to use the runtime's CLBuffer

* Fix OpenCL GT0
Author: Martin Loretz
Date: 2023-02-04 05:02:41 +01:00
Committed by: GitHub
Parent: ad4f6aa2cf
Commit: 4ad67b4bbc

3 changed files with 17 additions and 17 deletions
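The gist of the refactor: the Triton backend previously wrapped raw pycuda allocations in its own TritonWrapper class; after this change it reuses CLBuffer from tinygrad/runtime/cuda.py and only adds the data_ptr() accessor (plus a dtype tag) that Triton-compiled kernels expect. A minimal sketch of the resulting pattern, condensed from the hunks below; the round trip at the bottom is illustrative and not part of the commit:

import numpy as np
import pycuda.autoprimaryctx  # type: ignore # noqa: F401  (sets up the primary CUDA context)
import pycuda.driver as cuda  # type: ignore

class CLBuffer:  # as in tinygrad/runtime/cuda.py after this commit
  def __init__(self, size): self.cl = cuda.mem_alloc(size)
  def copyin(self, b, stream=None): cuda.memcpy_htod_async(self.cl, b, stream)
  def copyout(self, a): cuda.memcpy_dtoh(a, self.cl)

class TritonDeviceAllocation(CLBuffer):  # Triton reads raw device pointers via data_ptr()
  def data_ptr(self): return int(self.cl)

stream = cuda.Stream()
host = np.arange(8, dtype=np.float32)
buf = TritonDeviceAllocation(4*host.size)  # 4 bytes per float32, matching the diff
buf.copyin(host, stream)                   # async host-to-device on the backend's stream
out = np.empty_like(host)
buf.copyout(out)                           # blocking device-to-host
assert (out == host).all()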

@@ -3,7 +3,6 @@ import hashlib
 from weakref import WeakValueDictionary
 from torch import float32
 import numpy as np
-# import pycuda.autoinit # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.autoprimaryctx # type: ignore # noqa: F401
 import pycuda.driver as cuda # type: ignore
@@ -14,6 +13,7 @@ from typing import Union, Tuple, Optional, Dict
 from tinygrad.ops import MovementOps, UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, DEBUG, GlobalCounters
 from tinygrad.shape import ShapeTracker
 from tinygrad.helpers import prod
+from tinygrad.runtime.cuda import CLBuffer
 from tinygrad.ast import ASTKernel
 
 stream = cuda.Stream()
@@ -99,23 +99,23 @@ class TritonASTKernel(ASTKernel):
     def runner(*bufs):
       GlobalCounters.global_ops += self.info.flops
       GlobalCounters.global_mem += mem_estimate
-      return program[tuple(self.output_shape[::-1])](*[TritonWrapper(x.torch) for x in bufs], stream=stream.handle)
+      return program[tuple(self.output_shape[::-1])](*[x.cuda for x in bufs], stream=stream.handle)
     self.func_cache[self.key] = runner
     return runner
 
 class TritonBuffer(ExplicitExecAST):
   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[TritonBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
     super().__init__(shape, hostbuf)
-    self._buf : Optional[TritonWrapper] = hostbuf._buf if hostbuf is not None else None
+    self._buf : Optional[TritonDeviceAllocation] = hostbuf._buf if hostbuf is not None else None
     self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
     self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
-    if force_create: self.torch
+    if force_create: self.cuda
 
   @property
-  def torch(self):
+  def cuda(self):
     if self._buf is None:
-      self._buf = cuda.mem_alloc(4*prod(self._base_shape))
-      if self._backing is not None: cuda.memcpy_htod_async(self._buf, self._backing, stream=stream)
+      self._buf = TritonDeviceAllocation(4*prod(self._base_shape))
+      if self._backing is not None: self._buf.copyin(self._backing, stream)
     return self._buf
 
   @staticmethod
@@ -124,8 +124,7 @@ class TritonBuffer(ExplicitExecAST):
   def toCPU(self):
     data = np.empty(self.shape, dtype=np.float32)
     buf = self.contiguous() if self._buf is not None else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
-    # TODO should this be sync?
-    cuda.memcpy_dtoh_async(data, buf._buf, stream=stream)
+    buf._buf.copyout(data)
     return data
 
   @classmethod
@@ -134,10 +133,9 @@ class TritonBuffer(ExplicitExecAST):
     k.codegen()(*k.bufs)
     return k.ret
 
-class TritonWrapper:
-  def __init__(self, ptr):
-    self.ptr = ptr
+class TritonDeviceAllocation(CLBuffer):
+  def __init__(self, size):
+    super().__init__(size)
     self.dtype = float32
-  def data_ptr(self):
-    return int(self.ptr)
+  def data_ptr(self): return int(self.cl)
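Note what the toCPU hunk settles: the old path launched memcpy_dtoh_async into a freshly allocated numpy array and never synchronized, which is exactly what the deleted "TODO should this be sync?" comment was asking about. CLBuffer.copyout uses the blocking memcpy_dtoh, so the data is guaranteed valid when toCPU returns. A sketch of the two variants, with hypothetical dev_buf/shape/stream names:

data = np.empty(shape, dtype=np.float32)
# old: async copy; reading `data` before the stream drains races the transfer
cuda.memcpy_dtoh_async(data, dev_buf, stream=stream)
stream.synchronize()  # the step the TODO was hinting at
# new (via CLBuffer.copyout): blocks until the copy has landed
cuda.memcpy_dtoh(data, dev_buf)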

@@ -29,7 +29,8 @@ def split_float4(x):
 
 class CLASTKernel(ASTKernel):
   code_for_op : Dict[Op, str] = {
-    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.GT0: "((float)1.-step((float)0.,(-A)))",
+    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
+    UnaryOps.GT0: "(A > 0.)" if CUDA else "((float)1.-step((float)0.,(-A)))",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
     UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
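Why GT0 needed a CUDA branch: step() is an OpenCL builtin with no CUDA C counterpart, so the CUDA path falls back to a plain comparison. Both expressions evaluate to 1 where A > 0 and 0 elsewhere; a quick numpy check of the equivalence (mine, not part of the commit):

import numpy as np

def step(edge, x):  # OpenCL step(): 0.0 where x < edge, 1.0 otherwise
  return (x >= edge).astype(np.float32)

A = np.array([-2., -.5, 0., .5, 3.], dtype=np.float32)
opencl_gt0 = 1. - step(0., -A)           # the existing OpenCL expression
cuda_gt0 = (A > 0.).astype(np.float32)   # the new CUDA expression
assert (opencl_gt0 == cuda_gt0).all()    # both give [0, 0, 0, 1, 1]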

@@ -1,4 +1,5 @@
-import pycuda.autoinit # type: ignore # pylint: disable=unused-import # noqa: F401
+from typing import Optional
+import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
 import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import SourceModule # type: ignore
 import numpy as np
@@ -9,7 +10,7 @@ class CLImage:
 
 class CLBuffer:
   def __init__(self, size): self.cl = cuda.mem_alloc(size)
-  def copyin(self, b:np.ndarray): cuda.memcpy_htod_async(self.cl, b)
+  def copyin(self, b:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self.cl, b, stream)
   def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self.cl)
 
 class CLProgram:
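The optional stream parameter is what lets the Triton backend share this class: copyin defaults to stream=None, so existing call sites are untouched, while TritonBuffer.cuda can enqueue its host-to-device copy on the backend's own cuda.Stream(). A short usage sketch (the buffer and array names are illustrative):

import numpy as np
import pycuda.autoprimaryctx  # type: ignore # noqa: F401
import pycuda.driver as cuda  # type: ignore
from tinygrad.runtime.cuda import CLBuffer

src = np.ones(1024, dtype=np.float32)
buf = CLBuffer(src.nbytes)
buf.copyin(src)                 # pre-existing callers: default stream (None)
buf.copyin(src, cuda.Stream())  # Triton path: copy ordered on a private stream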