Use RawCUDABuffer

This commit is contained in:
Szymon Ożóg
2023-08-15 21:00:10 +02:00
parent 41ae7cb508
commit 50003a830f

View File

@@ -15,7 +15,8 @@ from typing import Any, Union, Tuple, Optional, Dict, List, Final, Callable
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters, Compiled
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import prod, DEBUG, dtypes, ImageDType
from tinygrad.runtime.ops_gpu import CLBuffer
from tinygrad.runtime.ops_gpu import RawBufferCopyInOut
from tinygrad.runtime.ops_cuda import RawCUDABuffer
from tinygrad.runtime.lib import RawBuffer
from tinygrad.codegen.linearizer import LinearizerOptions, UOp, UOps, LocalBuffer
from tinygrad.shape.symbolic import NumNode
@@ -34,13 +35,6 @@ class TritonProgram:
# Invoke `self.program` on the given buffers.
# Every positional argument in `args` must expose a `_buf` attribute; those
# `_buf` values are unpacked, in order, as the arguments to `self.program`.
# `global_size`, `local_size` and `wait` are accepted but never read in this
# body, and there is no explicit return, so the call evaluates to None.
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
self.program(*[x._buf for x in args])
# CLBuffer subclass that adds a `dtype` attribute and a `data_ptr()` accessor.
# NOTE(review): this excerpt is a diff view whose last line builds TritonBuffer
# from RawCUDABuffer rather than this class, so this class is presumably the
# code being removed -- confirm before relying on it.
class TritonDeviceAllocation(CLBuffer):
# Forward `size` to the CLBuffer constructor, then set `dtype`.
def __init__(self, size):
super().__init__(size)
# NOTE(review): bare `float32` is not among the imports visible in this
# excerpt (only `dtypes` is); confirm it is defined earlier in the file.
self.dtype = float32
# Return `self._cl` converted to a Python int. `_cl` is not assigned anywhere
# in this class -- presumably it is provided by CLBuffer; verify.
def data_ptr(self): return int(self._cl)
def uops_to_triton(function_name:str, uops:List[UOp]):
kernel = []
global_size: List[int] = []
@@ -112,4 +106,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
if DEBUG >= 4: print(prg)
return prg, global_size, local_size
# TritonBuffer is bound twice in a row; read as plain Python, the second
# binding (RawCUDABuffer) replaces the first (TritonDeviceAllocation).
# NOTE(review): the commit title "Use RawCUDABuffer" suggests these are the
# before/after versions of a single line in a diff whose +/- markers were
# lost -- confirm before treating both as live code.
TritonBuffer = Compiled(TritonDeviceAllocation, LinearizerOptions(), uops_to_triton, TritonProgram)
TritonBuffer = Compiled(RawCUDABuffer, LinearizerOptions(), uops_to_triton, TritonProgram)