diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index 46566c15e6..1f991a2fc0 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -15,7 +15,8 @@ from typing import Any, Union, Tuple, Optional, Dict, List, Final, Callable from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters, Compiled from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import prod, DEBUG, dtypes, ImageDType -from tinygrad.runtime.ops_gpu import CLBuffer +from tinygrad.runtime.ops_gpu import RawBufferCopyInOut +from tinygrad.runtime.ops_cuda import RawCUDABuffer from tinygrad.runtime.lib import RawBuffer from tinygrad.codegen.linearizer import LinearizerOptions, UOp, UOps, LocalBuffer from tinygrad.shape.symbolic import NumNode @@ -34,13 +35,6 @@ class TritonProgram: def __call__(self, global_size, local_size, *args, wait=False) -> Any: self.program(*[x._buf for x in args]) -class TritonDeviceAllocation(CLBuffer): - def __init__(self, size): - super().__init__(size) - self.dtype = float32 - - def data_ptr(self): return int(self._cl) - def uops_to_triton(function_name:str, uops:List[UOp]): kernel = [] global_size: List[int] = [] @@ -112,4 +106,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]): if DEBUG >= 4: print(prg) return prg, global_size, local_size -TritonBuffer = Compiled(TritonDeviceAllocation, LinearizerOptions(), uops_to_triton, TritonProgram) \ No newline at end of file +TritonBuffer = Compiled(RawCUDABuffer, LinearizerOptions(), uops_to_triton, TritonProgram) \ No newline at end of file