move all to compile api (#2203)

* move metal+clang to compile api

* all to the new style

* remove binary arg

* fix triton

* fixup tests

* fix clang

* diskcache is generic

* __wrapped__

* compile_gpu

* fix thneed

* keep the src in the ASTRunner

* lib

* move compile_gpu

* compile_gpu in device

* put compiler in astrunner

* test reverts

* triton compiler

* ugh, that too
This commit is contained in:
George Hotz
2023-11-01 23:01:32 -07:00
committed by GitHub
parent 8932816816
commit 03cf0afa4f
18 changed files with 128 additions and 136 deletions
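
The common shape of the change, as a hedged sketch (compile_sketch and SketchProgram are illustrative names, not part of the commit): every backend now exposes a standalone compile function from source string to bytes, decorated with @diskcache, and its Program class only takes those bytes and loads them. The per-runtime compile methods and the binary= flag go away.

from tinygrad.helpers import diskcache

@diskcache
def compile_sketch(prg:str) -> bytes:  # stand-in for compile_clang/compile_gpu/compile_cuda/...
  return prg.encode()                  # a real backend invokes its compiler here

class SketchProgram:                   # stand-in for ClangProgram/CLProgram/CUDAProgram/...
  def __init__(self, name:str, lib:bytes): self.name, self.lib = name, lib
  def __call__(self, global_size, local_size, *bufs, wait=False): pass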

View File

@@ -217,7 +217,7 @@ from tinygrad.runtime.lib import RawMallocBuffer
# ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10)
# __init__ calls clang, and __call__ calls the function in the *.so outputted by clang
# in CLANG, global_size and local_size are ignored
from tinygrad.runtime.ops_clang import ClangProgram
from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
# a concrete example looks like this, this adds two size 1 RawBuffer
# first we create two numpy buffers containing 2 and 3
@@ -229,7 +229,7 @@ input_a, input_b = RawMallocBuffer.fromCPU(numpy_a), RawMallocBuffer.fromCPU(num
output = RawMallocBuffer(1, dtypes.float32)
# compile the program, run it, and 2+3 does indeed equal 5
program = ClangProgram("add", f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
print(output.toCPU())
assert output.toCPU()[0] == 5, "it's still 5"

View File

@@ -4,7 +4,7 @@ import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
@@ -104,21 +104,11 @@ class Thneed:
if 'data' in o:
self.buffers_to_save.add(buf)
# load in the programs (this isn't used)
prgs = {}
for k,v in jdat['programs'].items():
print("building", k)
try:
prgs[k] = CLProgram(k, v, rename=False)
except Exception:
print("FAILED", k)
traceback.print_exc()
exit(0)
# load binaries
prgs = {}
for o in jdat['binaries']:
nptr = ptr + o['length']
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True)
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
ptr = nptr
# populate the cl_cache
@@ -208,7 +198,7 @@ class Thneed:
# zero out the buffer
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
CLProgram("from_image_strided", """
CLProgram("from_image_strided", compile_gpu("""
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
@@ -216,7 +206,7 @@ class Thneed:
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
"""), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
# multiple of 32 isn't enough
jdat['objects'].append({

View File

@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
from tinygrad.runtime.lib import RawBuffer
class FakeProgram:
def __init__(self, name:str, prg:str, binary:bool): pass
def __init__(self, name:str, prg:str): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass
class RawFakeBuffer(RawBuffer):

View File

@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
return ret.realized
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
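
Custom ops like atan2_gpu above now hand build() the device's compiler alongside its runtime, so hand-written kernels go through the same cached compile path as generated ones. A minimal sketch, assuming the CLANG backend is available (the kernel body is illustrative):

from tinygrad.ops import Device, ASTRunner
runner = ASTRunner("custom", "void custom(float *a) { *a = 42.0f; }")
runner.build(Device['CLANG'].compiler, Device['CLANG'].runtime)  # compile (diskcached), then load the .so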

View File

@@ -2,36 +2,34 @@
import unittest
import secrets
import string
import tempfile
import pathlib
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import cache_compiled
import tinygrad.runtime.ops_clang
from tinygrad.helpers import diskcache
def generate_random_string(length=16):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for _ in range(length))
compile_call_count = 0
@diskcache
def helper_test_compile(prg:str) -> bytes:
global compile_call_count
compile_call_count += 1
return prg.encode()
class TestKernelCache(unittest.TestCase):
compile_call_count = 0
@cache_compiled
def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs):
self.compile_call_count += 1
return prg.encode()
def test_compile_cache(self):
prg1 = generate_random_string(64) + "a"
prg2 = generate_random_string(64) + "b"
cold_compile_res = self.__helper_test_compile(prg1)
warm_compile_res = self.__helper_test_compile(prg1)
cold_compile_res = helper_test_compile(prg1)
warm_compile_res = helper_test_compile(prg1)
assert cold_compile_res == warm_compile_res == prg1.encode()
assert self.compile_call_count == 1
assert compile_call_count == 1
prg2_res = self.__helper_test_compile(prg2)
prg2_res = helper_test_compile(prg2)
assert prg2_res == prg2.encode()
assert self.compile_call_count == 2
assert compile_call_count == 2
def test_kernel_cache_in_action(self):
if Device.DEFAULT not in ["CLANG"]:
@@ -42,15 +40,15 @@ class TestKernelCache(unittest.TestCase):
x = a + b
x.realize()
orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable
orig_compile_func = Device['CLANG'].compiler
Device['CLANG'].compiler = None # making it not callable
a1 = Tensor.rand(4,4)
b1 = Tensor.rand(4,4)
x1 = a1 + b1
x1.realize() # Same kernel should be from cache.
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func
Device['CLANG'].compiler = orig_compile_func
if __name__ == "__main__":
unittest.main()

View File

@@ -8,7 +8,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime)
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))

View File

@@ -153,22 +153,12 @@ class GlobalCounters:
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
# *** compiled cache decorator ***
def cache_compiled(func):
if getenv("DISABLE_COMPILER_CACHE"): return func
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(self, prg, *args, **kwargs))
return wrapper
# *** universal database cache ***
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 5
VERSION = 6
_db_connection = None
def db_connection():
global _db_connection
@@ -207,3 +197,11 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
conn.commit()
cur.close()
return val
def diskcache(func):
def wrapper(*args, **kwargs) -> bytes:
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(*args, **kwargs))
setattr(wrapper, "__wrapped__", func)
return wrapper
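
A hedged usage sketch of the new decorator (slow_compile is an illustrative name): results are keyed by a sha256 of the pickled (args, kwargs) and stored in the on-disk cache, while __wrapped__ keeps the undecorated function reachable for the DISABLE_COMPILER_CACHE bypass. Note that since only the arguments actually passed are pickled, a default argument enters the cache key only when supplied explicitly.

from tinygrad.helpers import diskcache

@diskcache
def slow_compile(src:str) -> bytes:
  return src.encode()  # stand-in for a real compiler invocation

lib = slow_compile("kernel source")              # cold: computes and stores
lib = slow_compile("kernel source")              # warm: served from the cache
raw = slow_compile.__wrapped__("kernel source")  # bypasses the cache entirely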

View File

@@ -194,8 +194,8 @@ class GraphBatchExecutor(BasicBatchExecutor):
def exec_instance(self, instid): raise NotImplementedError("must be implemented")
class ASTRunner:
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
@@ -211,8 +211,9 @@ class ASTRunner:
return float('inf')
return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
def build(self, runtime, batch_exec=BasicBatchExecutor):
self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
def build(self, compiler, runtime, batch_exec=BasicBatchExecutor):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec
return self
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
@@ -243,8 +244,8 @@ class ASTRunner:
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_exec
self.method_cache: Dict[LazyOp, ASTRunner] = {}
def to_program(self, k):
@@ -252,7 +253,7 @@ class Compiled:
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.runtime, self.batch_exec)
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime, self.batch_exec)
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
# check if we can reuse the output buffer
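
A minimal sketch of the new Compiled signature (my_compile and SketchProgram here are hypothetical): the compiler slot sits between the renderer and the runtime, matching the three-argument build() above.

import functools
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

def my_compile(prg:str) -> bytes: return prg.encode()  # a real backend would @diskcache this

class SketchProgram:
  def __init__(self, name:str, lib:bytes): pass
  def __call__(self, global_size, local_size, *bufs, wait=False): pass

renderer = functools.partial(uops_to_cstyle, CStyleLanguage())
SketchBuffer = Compiled(RawMallocBuffer, LinearizerOptions(has_local=False), renderer, my_compile, SketchProgram)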

View File

@@ -209,4 +209,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
else:
raise RuntimeError(f"failed to render {uop}")
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {"binary":False}
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}

View File

@@ -144,4 +144,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
bb[-1].ret_void()
return str(module), {"binary":False}
return str(module), {}

View File

@@ -118,7 +118,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
for x in local_size: acc_local_size *= next_power_of_2(x)
local_size = [acc_local_size] + [1] * (len(local_size) - 1)
if DEBUG >=4: print(prg)
if DEBUG >= 4: print(prg)
getlines = linecache.getlines
linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122
@@ -126,4 +126,5 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}

View File

@@ -1,7 +1,7 @@
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any
from tinygrad.ops import Compiled
from tinygrad.helpers import cache_compiled
from tinygrad.helpers import diskcache
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -13,26 +13,24 @@ args = {
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
class ClangProgram:
def __init__(self, name:str, prg:str, binary=False):
self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
@diskcache
def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
class ClangProgram:
def __init__(self, name:str, prg:bytes):
# write to disk so we can load it
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
pathlib.Path(cached_file_path.name).write_bytes(prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
@cache_compiled
def compile(self, prg) -> bytes:
# TODO: sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()
def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
if wait: st = time.perf_counter()
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.perf_counter()-st
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
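
A hedged sketch of the cached path (assuming clang is installed): identical source only invokes the compiler once, with CLANG_PROGRAM_HEADER prepended inside compile_clang.

from tinygrad.runtime.ops_clang import compile_clang
lib1 = compile_clang("void f(float *a) { *a = 1.0f; }")  # cold: runs clang, stores the .so bytes
lib2 = compile_clang("void f(float *a) { *a = 1.0f; }")  # warm: returned straight from the disk cache
assert lib1 == lib2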

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional, List, Any, Tuple
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
from tinygrad.helpers import DEBUG, getenv, colored, diskcache
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions
@@ -88,9 +88,12 @@ class CUDAGraph(GraphBatchExecutor):
def exec_instance(self, instid): self.graphs[instid][0].launch()
@diskcache
def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
if not binary: prg = self.compile(prg).decode('utf-8')
def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
prg = _prg.decode('utf-8')
if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6:
try:
@@ -102,10 +105,6 @@ class CUDAProgram:
# TODO: name is wrong, so we get it from the ptx using hacks
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
@cache_compiled
def compile(self, prg) -> bytes:
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
def __call__(self, global_size, local_size, *args, wait=False):
if wait:
start, end = cuda.Event(), cuda.Event()
@@ -118,7 +117,8 @@ class CUDAProgram:
if getenv("TRITON") == 1:
from tinygrad.renderer.triton import uops_to_triton
TritonRenderer = uops_to_triton
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), TritonRenderer, CUDAProgram, cuda.Context.synchronize)
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False),
uops_to_triton, lambda x: x.encode('utf-8'), CUDAProgram, cuda.Context.synchronize)
else:
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), CUDARenderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]),
CUDARenderer, compile_cuda, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
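
Usage follows the same shape, as a hedged sketch (assuming a CUDA device and pycuda): compile_cuda returns diskcached PTX bytes, which CUDAProgram decodes and loads, recovering the kernel name from the .visible .entry in the PTX. In the TRITON branch the renderer already emits PTX, so the compiler slot is just an encode step.

from tinygrad.runtime.ops_cuda import CUDAProgram, compile_cuda
ptx = compile_cuda('extern "C" __global__ void noop() {}')  # PTX bytes, diskcached
prg = CUDAProgram("noop", ptx)
prg([1,1,1], [1,1,1])  # global_size, local_size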

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import os
os.environ['PYOPENCL_NO_CACHE'] = '1'
import pathlib
import numpy as np
import pyopencl as cl # type: ignore
from typing import Optional, List
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
from tinygrad.ops import Compiled
from tinygrad.renderer.opencl import OpenCLRenderer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
@@ -61,23 +63,28 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
@diskcache
def compile_gpu(prg:str) -> bytes:
clprg = cl.Program(CL.cl_ctxs[0], prg)
clprg.build()
return clprg.get_info(cl.program_info.BINARIES)[0]
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
def __init__(self, name:str, prg:bytes, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs] # type: ignore
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
fromimport('disassemblers.adreno', 'disasm')(self.binary())
fromimport('disassemblers.adreno', 'disasm')(prg)
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
else:
# print the PTX for NVIDIA. TODO: probably broken for everything else
print(self.binary().decode('utf-8'))
print(prg.decode('utf-8'))
if argdtypes is not None: self.set_argdtypes(argdtypes)
def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0]
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
@staticmethod
@@ -100,4 +107,4 @@ class CLProgram:
return None
return None
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, CLProgram, CL.synchronize)
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize)
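
A hedged usage sketch (assuming an OpenCL device): compile_gpu builds the source once and returns the program binary, and CLProgram now always loads binaries, which is why thneed above can construct one straight from saved bytes.

from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
lib = compile_gpu("__kernel void noop() {}")  # device binary bytes, diskcached
prg = CLProgram("noop", lib)
prg([1,1,1], None)  # global_size, local_size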

View File

@@ -2,7 +2,7 @@ import numpy as np
import ctypes, functools
import extra.hip_wrapper as hip
from typing import Tuple, Any, List
from tinygrad.helpers import DEBUG, getenv, cache_compiled
from tinygrad.helpers import DEBUG, getenv, diskcache
from tinygrad.ops import Compiled, ASTRunner, GraphBatchExecutor
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.kernel import LinearizerOptions
@@ -78,10 +78,15 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
hip.hipSetDevice(x._device)
hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)
@diskcache
def compile_hip(prg) -> bytes:
prog = hip.hiprtcCreateProgram(prg, "<null>", [], [])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
class HIPProgram:
def __init__(self, name:str, prg:str, binary=False):
def __init__(self, name:str, prg:bytes):
self.modules, self.prgs = [], []
prg = prg if binary else self.compile(prg, name)
if DEBUG >= 6:
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
@@ -92,12 +97,6 @@ class HIPProgram:
self.modules.append(hip.hipModuleLoadData(prg))
self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name))
@cache_compiled
def compile(self, prg, name) -> bytes:
prog = hip.hiprtcCreateProgram(prg, name, [], [])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
def __call__(self, global_size, local_size, *args, wait=False):
hip.hipSetDevice(args[0]._device)
if wait:
@@ -138,4 +137,4 @@ __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset
""",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]))
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, HIPProgram, hip.hipDeviceSynchronize, HIPGraph)
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, HIPGraph)
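
Same pattern for HIP, as a hedged sketch (assuming a HIP device): note that compile_hip passes "<null>" instead of the kernel name to hiprtcCreateProgram, so the cache key no longer depends on the name and identical source is shared across kernels.

from tinygrad.runtime.ops_hip import HIPProgram, compile_hip
lib = compile_hip('extern "C" __global__ void noop() {}')  # hiprtc code object bytes, diskcached
prg = HIPProgram("noop", lib)  # loads the module on every HIP device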

View File

@@ -1,7 +1,7 @@
import time, ctypes
from typing import ClassVar
from tinygrad.ops import Compiled
from tinygrad.helpers import getenv, DEBUG, cache_compiled
from tinygrad.helpers import getenv, DEBUG, diskcache
from ctypes import CFUNCTYPE
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
@@ -9,6 +9,8 @@ from tinygrad.runtime.lib import RawMallocBuffer
import llvmlite.binding as llvm # type: ignore
LLVMOPT = bool(getenv("LLVMOPT"))
class LLVM:
target_machine: ClassVar[llvm.targets.TargetMachine] = None
engine: ClassVar[llvm.executionengine.ExecutionEngine] = None
@@ -26,7 +28,7 @@ class LLVM:
LLVM.target_machine.add_analysis_passes(LLVM.optimizer)
# TODO: this makes compile times so much faster
if getenv("LLVMOPT"):
if LLVMOPT:
llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize')
#llvm.set_option(str(), '--debug')
@@ -44,19 +46,18 @@ class LLVM:
backing_mod.triple = llvm.get_process_triple()
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
class LLVMProgram:
def __init__(self, name:str, prg:str, binary=False):
self.prg = prg if binary else self.compile(prg)
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg))
self.fxn = LLVM.engine.get_function_address(name)
@diskcache
def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes:
mod = llvm.parse_assembly(prg)
mod.verify()
LLVM().optimizer.run(mod)
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
return LLVM.target_machine.emit_object(mod)
@cache_compiled
def compile(self, prg) -> bytes:
mod = llvm.parse_assembly(prg)
mod.verify()
LLVM().optimizer.run(mod)
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
return LLVM.target_machine.emit_object(mod)
class LLVMProgram:
def __init__(self, name:str, lib:bytes):
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
self.fxn = LLVM.engine.get_function_address(name)
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
@@ -64,4 +65,4 @@ class LLVMProgram:
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
if wait: return time.perf_counter()-st
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram)
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)
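
A hedged sketch for LLVM (llvmlite required): compile_llvm parses and verifies textual IR, runs the optimizer, and returns the emitted object bytes; LLVMProgram just registers the object file with the JIT and resolves the symbol address.

from tinygrad.runtime.ops_llvm import LLVMProgram, compile_llvm
obj = compile_llvm("define void @noop() { ret void }")  # object bytes, diskcached
prg = LLVMProgram("noop", obj)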

View File

@@ -3,7 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch # type: ignore
from typing import List, Any, Tuple
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
@@ -58,9 +58,21 @@ def unwrap(x):
assert err is None, str(err)
return ret
@diskcache
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
if use_xcode:
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
return pathlib.Path(output_file.name).read_bytes()
class MetalProgram:
def __init__(self, name:str, prg:str, binary:bool=False):
lib = prg if binary else self.compile(prg)
def __init__(self, name:str, lib:bytes):
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name)
@@ -71,19 +83,6 @@ class MetalProgram:
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
@cache_compiled
def compile(self, prg) -> bytes:
if getenv("METAL_XCODE"):
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
return pathlib.Path(output_file.name).read_bytes()
def __call__(self, global_size, local_size, *bufs, wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = METAL.mtl_queue.commandBuffer()
@@ -101,4 +100,4 @@ class MetalProgram:
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, MetalProgram, METAL.synchronize, MetalBatchExecutor)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, MetalBatchExecutor)
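
The same shape on macOS, as a hedged sketch: compile_metal returns metallib bytes (via xcrun when METAL_XCODE is set, otherwise through the Metal framework plus a temp-file serialization) and MetalProgram loads them with dispatch_data.

from tinygrad.runtime.ops_metal import MetalProgram, compile_metal
lib = compile_metal("kernel void noop() {}")  # metallib bytes, diskcached
prg = MetalProgram("noop", lib)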

View File

@@ -12,7 +12,7 @@ import wgpu # type: ignore
wgpu_device = get_default_device()
class WebGPUProgram:
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __call__(self, global_size, local_size, *bufs, wait=False):
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
@@ -42,4 +42,4 @@ class RawWebGPUBuffer(RawBufferCopyIn):
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)
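
WEBGPU is the one backend without an offline compile step: create_shader_module consumes WGSL source directly, so the compiler slot is the identity function and WebGPUProgram still takes a str. A hedged sketch, assuming wgpu is installed and the ops_webgpu module path:

from tinygrad.runtime.ops_webgpu import WebGPUProgram
prg = WebGPUProgram("main", "@compute @workgroup_size(1) fn main() {}")  # source str, not bytes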