mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Use RawCUDABuffer
This commit is contained in:
@@ -15,7 +15,8 @@ from typing import Any, Union, Tuple, Optional, Dict, List, Final, Callable
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters, Compiled
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.helpers import prod, DEBUG, dtypes, ImageDType
|
||||
from tinygrad.runtime.ops_gpu import CLBuffer
|
||||
from tinygrad.runtime.ops_gpu import RawBufferCopyInOut
|
||||
from tinygrad.runtime.ops_cuda import RawCUDABuffer
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions, UOp, UOps, LocalBuffer
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
@@ -34,13 +35,6 @@ class TritonProgram:
|
||||
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
|
||||
self.program(*[x._buf for x in args])
|
||||
|
||||
class TritonDeviceAllocation(CLBuffer):
|
||||
def __init__(self, size):
|
||||
super().__init__(size)
|
||||
self.dtype = float32
|
||||
|
||||
def data_ptr(self): return int(self._cl)
|
||||
|
||||
def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
kernel = []
|
||||
global_size: List[int] = []
|
||||
@@ -112,4 +106,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
if DEBUG >= 4: print(prg)
|
||||
return prg, global_size, local_size
|
||||
|
||||
TritonBuffer = Compiled(TritonDeviceAllocation, LinearizerOptions(), uops_to_triton, TritonProgram)
|
||||
TritonBuffer = Compiled(RawCUDABuffer, LinearizerOptions(), uops_to_triton, TritonProgram)
|
||||
Reference in New Issue
Block a user