From 50003a830f9b7bcf995ab983133479e638d9b483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Tue, 15 Aug 2023 21:00:10 +0200 Subject: [PATCH] Use RawCUDABuffer --- tinygrad/runtime/ops_triton.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index 46566c15e6..1f991a2fc0 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -15,7 +15,8 @@ from typing import Any, Union, Tuple, Optional, Dict, List, Final, Callable from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, GlobalCounters, Compiled from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import prod, DEBUG, dtypes, ImageDType -from tinygrad.runtime.ops_gpu import CLBuffer +from tinygrad.runtime.ops_gpu import RawBufferCopyInOut +from tinygrad.runtime.ops_cuda import RawCUDABuffer from tinygrad.runtime.lib import RawBuffer from tinygrad.codegen.linearizer import LinearizerOptions, UOp, UOps, LocalBuffer from tinygrad.shape.symbolic import NumNode @@ -34,13 +35,6 @@ class TritonProgram: def __call__(self, global_size, local_size, *args, wait=False) -> Any: self.program(*[x._buf for x in args]) -class TritonDeviceAllocation(CLBuffer): - def __init__(self, size): - super().__init__(size) - self.dtype = float32 - - def data_ptr(self): return int(self._cl) - def uops_to_triton(function_name:str, uops:List[UOp]): kernel = [] global_size: List[int] = [] @@ -112,4 +106,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]): if DEBUG >= 4: print(prg) return prg, global_size, local_size -TritonBuffer = Compiled(TritonDeviceAllocation, LinearizerOptions(), uops_to_triton, TritonProgram) \ No newline at end of file +TritonBuffer = Compiled(RawCUDABuffer, LinearizerOptions(), uops_to_triton, TritonProgram) \ No newline at end of file