From 4bef1591f0bb0ee36766b8e609f23d331c6888d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com>
Date: Tue, 17 Oct 2023 19:33:32 +0200
Subject: [PATCH] Disable ocelot cache + fix matvec in triton (#2010)

* Revert "disable flaky triton test"

This reverts commit 1e15fdaee753dcf3650af288551f22b51092c001.

* Update test.yml

* check if has shared for matvec

* disable ocelot cache for triton

* disable ocelot cache

* disable ocelot cache

* pass shared to triton uops tests

* temporary debugs for CI crash

* Revert "temporary debugs for CI crash"

This reverts commit fee3ea96c818e83c19b935c2f8482e0ccc91a542.

* Revert "triton isn't tested, and allows this refactor (#2007)"

This reverts commit dea8bb09386d9565c636bdbb024af4556e28c7b1.

* add runtime_args to every renderer, move triton local size override to runtime args

* Add binary to args, correct type returned

* update to new loops

* Update test.yml
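
In short: every Compiled renderer now returns a (src, runtime_args) pair instead of a bare source
string, and runtime_args is passed straight through to the program constructor. A rough sketch of
the new contract (illustrative only; the exact dict contents come from the diffs below, and the
shared/local_size numbers here are made up):

    # sketch of the new renderer contract, mirroring the updated helper in test/test_uops.py
    src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
    # cstyle/llvmir renderers return {"binary": False}; the triton renderer returns something like
    # {"binary": True, "shared": 1024, "local_size_override": [32, 1, 1]}
    prg = ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime)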
---
 .github/workflows/test.yml                    |  6 +++---
 test/test_uops.py                             |  4 ++--
 tinygrad/codegen/optimizer.py                 |  2 +-
 tinygrad/ops.py                               |  5 ++++-
 tinygrad/renderer/cstyle.py                   |  4 ++--
 tinygrad/renderer/llvmir.py                   |  6 +++---
 {extra/triton => tinygrad/renderer}/triton.py |  4 ++--
 tinygrad/runtime/ops_cuda.py                  | 17 +++++++++++------
 8 files changed, 28 insertions(+), 20 deletions(-)
 rename {extra/triton => tinygrad/renderer}/triton.py (98%)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8417a3365b..83f77a0dfe 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -227,7 +227,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        backend: [llvm, clang, gpu, cuda] #, triton] #, ptx]
+        backend: [llvm, clang, gpu, cuda, triton] #, ptx]

     name: Tests on (${{ matrix.backend }})
     runs-on: ${{ matrix.backend == 'gpu' && 'ubuntu-20.04' || 'ubuntu-latest' }}
@@ -259,7 +259,7 @@
          sudo apt update -y
          sudo apt install -y --no-install-recommends git g++ cmake ninja-build llvm-15-dev zlib1g-dev libglew-dev flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc
     - name: Cache gpuocelot
-      if: matrix.backend == 'cuda' || matrix.backend == 'ptx'|| matrix.backend == 'triton'
+      if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
       id: cache-build
       uses: actions/cache@v3
       env:
@@ -299,7 +299,7 @@
       run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
     - name: Run pytest (triton)
       if: matrix.backend=='triton'
-      run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models
+      run: python -m pytest -v -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models

   testunicorn:
     name: ARM64 unicorn Test
diff --git a/test/test_uops.py b/test/test_uops.py
index 477173a651..b5fc990849 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -7,8 +7,8 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp

 def _uops_to_prg(uops):
-  src = Device[Device.DEFAULT].renderer("test", uops)
-  return ASTRunner("test", src[0] if getenv("TRITON") else src, [1], [1], runtime_args={"binary": getenv("TRITON")}).build(Device[Device.DEFAULT].runtime)
+  src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
+  return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime)

 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py
index c4d4095245..341a591dc8 100644
--- a/tinygrad/codegen/optimizer.py
+++ b/tinygrad/codegen/optimizer.py
@@ -331,7 +331,7 @@ class OptimizedKernel(Kernel):
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-        self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and \
+        self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
         isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
         self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM:
       buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index e3ab9778ab..cc1bd34854 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -233,7 +233,10 @@ class Compiled:

   def to_program(self, k):
     k.linearize()
-    return ASTRunner.from_linearizer(k, self.renderer(k.function_name, k.uops)).build(self.runtime, self.batch_exec)
+    src, runtime_args = self.renderer(k.function_name, k.uops)
+    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
+                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
+                     display_name=k.display_name, runtime_args=runtime_args).build(self.runtime, self.batch_exec)

   def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # check if we can reuse the output buffer
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 537929bea0..68ab8a8535 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -106,7 +106,7 @@ class CStyleLanguage(NamedTuple):
       return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
     return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

-def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
+def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
   local_size: List[int] = []
   kernel,prekernel,bufs = [],[],[]
   #pend_close = None
@@ -209,4 +209,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
       else:
         raise RuntimeError(f"failed to render {uop}")

-  return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel)
+  return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {"binary":False}
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index b1f2ab6eeb..5f16fe561c 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,4 +1,4 @@
-from typing import Final, Dict, Callable, Any, List, Optional
+from typing import Final, Dict, Callable, Any, List, Optional, Tuple
 from llvmlite import ir # type: ignore
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes
@@ -57,7 +57,7 @@ def cast(bb, val, input_type, output_type):

   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")

-def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
+def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
   # all llvm stuff goes into a module
   module = ir.Module(name=__file__)

@@ -139,4 +139,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
         lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])

   bb[-1].ret_void()
-  return str(module)
+  return str(module), {"binary":False}
diff --git a/extra/triton/triton.py b/tinygrad/renderer/triton.py
similarity index 98%
rename from extra/triton/triton.py
rename to tinygrad/renderer/triton.py
index 0a8ca7cdb4..0e0dcdfa4c 100644
--- a/extra/triton/triton.py
+++ b/tinygrad/renderer/triton.py
@@ -71,7 +71,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   for u in uops:
     uop,dtype,vin,args,_ = u
     if uop == UOps.LOOP:
-      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}+{define_scalar([], 'tl.int32', 1)}):")
+      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
       depth += 1
     elif uop == UOps.END: depth -= 1
     elif uop == UOps.ALU:
@@ -122,4 +122,4 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   if getenv("CUDACPU"):
     prg = remove_single_scalar_curly_braces(prg.split(".file")[0].split(".visible .func")[0])
     max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
     for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
-  return prg, local_size, {"binary":True, "shared":compiled.metadata["shared"]}
+  return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index fea29c1b46..69c68d9a34 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Optional
 import numpy as np
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG, getenv, colored
+from tinygrad.helpers import DEBUG, getenv, colored, fromimport
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
 from tinygrad.codegen.kernel import LinearizerOptions
@@ -58,7 +58,7 @@ else:
     def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore

 class CUDAProgram:
-  def __init__(self, name:str, prg:str, binary=False, shared = 0):
+  def __init__(self, name:str, prg:str, binary=False, shared = 0, local_size_override=None):
     if not binary:
       try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
       except cuda.CompileError as e:
@@ -73,13 +73,13 @@ class CUDAProgram:
         print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
       except Exception as e: print("failed to generate SASS", str(e))
     # TODO: name is wrong, so we get it from the ptx using hacks
-    self.prg, self.shared = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared
+    self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override

   def __call__(self, global_size, local_size, *args, wait=False):
     if wait:
       start, end = cuda.Event(), cuda.Event()
       start.record()
-    self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=self.shared)
+    self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size if self.local_size_override is None else self.local_size_override), grid=tuple(global_size), shared=self.shared)
     if wait:
       end.record()
       end.synchronize()
@@ -97,5 +97,10 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
   __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
   __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
 };
-  """))
-CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
+  """)) if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm")
+if getenv("TRITON") == 1:
+  from tinygrad.renderer.triton import uops_to_triton
+  renderer = uops_to_triton
+  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), renderer, CUDAProgram, cuda.Context.synchronize)
+else:
+  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
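
A short launch-side sketch of how the triton runtime_args end up being used by CUDAProgram
(parameter names follow the diff above; the kernel name, buffers and sizes below are invented for
illustration only):

    # hypothetical example: a program built from the triton renderer's PTX output
    prg = CUDAProgram("r_example", ptx_src, binary=True, shared=2048, local_size_override=[64, 1, 1])
    # when local_size_override is set, the local_size argument of the call is ignored and the
    # triton-chosen block size (padded to three dimensions by the renderer) is used instead
    prg([128, 1, 1], [1, 1, 1], out_buf, a_buf, b_buf)  # out_buf/a_buf/b_buf: RawCUDABuffer instances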