Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-07)
* Move ops_triton to runtime and remove errors from deprecated code
* Remove deprecated AST Kernel
* Remove deprecated buffer
* Add TritonProgram
* Triton Buffer
* Use RawCUDABuffer
* triton_compile
* Added new parameter
* pass _buf to program
* remove deprecated include
* Added triton tests
* Deprecated includes removed
* remove double print
* Disable float4 support
* Disable float4 support
* variable load fix
* Track local size
* Add pycuda to triton dependencies
* Merge test.yml
* install cuda packages for testing
* merge double package install
* remove emulated from triton tests
* upscale local index to power of 2 and add masking
* cuda envs
* Add TernaryOps
* ConstOp loading
* proper function name
* remove deprecated variables
* get global program from name
* const ops match local shape
* Enable test_nn
* remove deprecated import
* fix linter error
* Add wait logic
* Add local size override
* accumulate local shapes instead of using max shape
* Merge triton tests into global tests
* fix envs in testing
* Old testing routine
* split file into renderer and program
* remove print and starting whitespace
* pretty ptx print on debug 5
* linter errors
* ignore triton saturation tests
* ignore test example
* remove pytorch cpu extra index
* Add triton to existing testing routine
* use triton tests
* disable cuda backend in triton tests
* use cudacpu in tests
* print used device
* Print device default
* Remove print
* ensure we are running triton backend
* update variable signatures
* update dtypes for load
* infinity render fixed
* limit global size
* negative infinity now properly rendered
* split chain with parentheses for and node
* Add option to disable shared memory, disable for triton
* missing import
* Properly index and mask conditional load
* use mask only if not loading a block pointer
* nan support
* fix symbolic tests to include chain split
* proper masking for stores
* Implemented bool dtype
* Add mod
* fix loads for variables with valid range
* merge triton with cuda runtime
* merge from master
* run triton tests with cuda
* Correct target when running from triton
* conftest with triton compiler config
* use triton nightly
* verbose tests for triton
* capture stdout
* fix function depth when exiting multiple loops
* add render valid function for readability
* fix mask for local loops
* add _arg_int32 datatype
* fix dims for conditional loads
* enable non-float stores
* correct variable dtypes
* fix type for arg_int32
* remove junk
* Added get max function for range based var.max
* remove deprecated code
* Fix triton ptxas path
* Fix testing for CI
* clamp local size by max local size instead of always running max
* Disable matmul test in triton cpu
* rerun tests
* Disable broken test in triton cpu
* whitespace removed
* rerun tests again
* Disable TestSymbolicOps for triton
* update to new uops
* linter fix
* ignore test/extra
* linting fix
* Update tinygrad/renderer/triton.py
  Co-authored-by: Gijs Koning <gijs-koning@live.nl>
* remove deprecated line
* quotes type fix
* linter
* Remove unnecessary lines
* UnaryOps.NEG
* don't define constants
* Linting fix
* Disable tests that are broken in ocelot
* remove trailing whitespace
* reduce line count
* linting fix
* update to new uast
* New looping style
* Update to new uast
* make AST runner work with triton
* linting fix
* set renderer var for testing
* disable local for ocelot
* reenable all tests for ocelot
* Pass shared to cuda
* Don't group if the backend doesn't support shared mem
* use working gpuocelot branch
* enable all tests
* enable local for ocelot
* cleanup
* Update test.yml
* update cache key
* reenable test symbolic and extra
* Update test.yml
* Revert "Update test.yml" (rerun tests)
  This reverts commit 98c0630ee5.
* Revert "fix symbolic tests to include chain split"
  This reverts commit 22a9a4c9cd.
* Revert "split chain with parentheses for and node"
  This reverts commit 7499a7004e.
* use global size from linearizer
* rename newvar to dtype to match other renderers
* join program start lines
* simplify code that adds axis to local dims
* assign r[u] in ssa
* We no longer need to replace target in src
* we no longer need to cast indices to int by hand
* Update triton.py (rerun tests)
* Update triton.py (rerun tests)
* Update triton.py (rerun tests)

---------

Co-authored-by: Gijs Koning <gijs-koning@live.nl>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
106 lines
6.6 KiB
Python
import subprocess, time, re, hashlib, tempfile, functools
from pathlib import Path
from typing import Optional
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored, fromimport
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

def pretty_ptx(s):
  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
  return s
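
# A minimal usage sketch (the PTX fragment here is hypothetical): pretty_ptx only wraps matched
# tokens in ANSI color codes via colored(), so the returned string is the same PTX text, colorized.
#   print(pretty_ptx(".visible .entry r_16(.param .u64 buf0) { ld.global.f32 %f1, [%rd1]; }"))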

def arch(): return "sm_" + "".join([str(x) for x in pycuda.driver.Context.get_device().compute_capability()])
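# e.g. arch() returns "sm_86" for compute capability (8, 6); the exact value depends on the local GPU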

if getenv("CUDACPU", 0) == 1:
  import ctypes, ctypes.util
  lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
  # ptx_run(src, nargs, args, block_x, block_y, block_z, grid_x, grid_y, grid_z, shared)
  lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
  class cuda:
    class module:
      def __init__(self, src): self.src = src
      def get_function(self, _): return self
      def __call__(self, *args, block, grid, shared): lib.ptx_run(self.src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), *block, *grid, shared)
    module_from_buffer = lambda src: cuda.module(src) # pylint: disable=unnecessary-lambda # noqa: E731
    class Event:
      def __init__(self): pass
      def record(self): self.start = time.perf_counter()
      def time_till(self, other): return other.start - self.start # other - self, so start.time_till(end) is a positive duration
      def synchronize(self): pass
    class Context:
      synchronize = lambda:0 # noqa: E731
    CompileError = Exception
  class context:
    class device:
      compute_capability = lambda: (3,5) # pylint: disable=unnecessary-lambda # noqa: E731
    get_device = lambda: context.device # pylint: disable=unnecessary-lambda # noqa: E731
  import pycuda.driver # type: ignore
  pycuda.driver.Context = context
  RawCUDABuffer = RawMallocBuffer
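  # Under CUDACPU=1 there is no real device: PTX is interpreted on the host by gpuocelot, "device"
  # buffers are plain host allocations (RawMallocBuffer), and the shims above provide the small
  # slice of the pycuda API that CUDAProgram touches.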
else:
  import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
  import pycuda.driver as cuda # type: ignore
  class CUDAAllocator(LRUAllocator):
    def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
    def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
  CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
  class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
    def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
    def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
    def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
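# Usage sketch (assumes the fromCPU/toCPU helpers RawBufferCopyInOut inherits in tinygrad.runtime.lib):
#   buf = RawCUDABuffer.fromCPU(np.arange(4, dtype=np.float32))  # host -> device via _copyin
#   buf.toCPU()                                                  # device -> host via _copyout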


class CUDAProgram:
  def __init__(self, name:str, prg:str, binary=False, shared=0):
    if not binary:
      try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
      except cuda.CompileError as e:
        if DEBUG >= 3: print("FAILED TO BUILD", prg)
        raise e
    if DEBUG >= 5: print(pretty_ptx(prg))
    if DEBUG >= 6:
      # assemble the PTX with ptxas and print the SASS disassembly from nvdisasm
      try:
        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}").as_posix()
        with open(fn + ".ptx", "wb") as f: f.write(prg.encode('utf-8'))
        subprocess.run(["ptxas", f"-arch={arch()}", "-o", fn, fn+".ptx"], check=True)
        print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
      except Exception as e: print("failed to generate SASS", str(e))
    # TODO: name is wrong, so we get it from the ptx using hacks
    self.prg, self.shared = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared

  def __call__(self, global_size, local_size, *args, wait=False):
    if wait:
      start, end = cuda.Event(), cuda.Event()
      start.record()
    self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=self.shared)
    if wait:
      end.record()
      end.synchronize()
      return start.time_till(end)*1e-3
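# Launch sketch (hypothetical kernel; since the entry name is recovered from the PTX itself, the
# `name` argument is effectively cosmetic):
#   prg = CUDAProgram("copy", "__global__ void copy(float *a, const float *b) { a[blockIdx.x] = b[blockIdx.x]; }")
#   t = prg([16, 1, 1], [1, 1, 1], out_buf, in_buf, wait=True)  # wait=True returns elapsed seconds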


renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
  kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
  gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
  lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
  half_prekernel = """
    #include <cuda_fp16.h>
    struct __align__(8) half4 {
      half2 x, y;
      __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
      __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
    };
    """)) if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm")

if getenv("TRITON") == 1:
  from tinygrad.renderer.triton import uops_to_triton
  renderer = uops_to_triton
  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], has_shared=False), renderer, CUDAProgram, cuda.Context.synchronize)
else:
  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
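# Selection sketch: tinygrad dispatches to this backend through the "CUDA" device name, and the
# environment variables referenced above change what runs underneath (Tensor usage assumes
# tinygrad's standard API):
#   CUDACPU=1 -> PTX interpreted on the host via gpuocelot
#   TRITON=1  -> Triton renderer, with shared memory disabled
#   PTX=1     -> assembly_ptx renderer instead of CUDA C
# e.g.  from tinygrad.tensor import Tensor; Tensor.ones(2, 2, device="CUDA").numpy()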