mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
move all to compile api (#2203)
* move metal+clang to compile api * all to the new style * remove binary arg * fix triton * fixup tests * fix clang * diskcache is generic * __wrapped__ * compile_gpu * fix thneed * keep the src in the ASTRunner * lib * move compile_gpu * compile_gpu in device * put compiler in astrunner * test reverts * triton compiler * ugh, that too
This commit is contained in:
@@ -217,7 +217,7 @@ from tinygrad.runtime.lib import RawMallocBuffer
|
||||
# ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10)
|
||||
# __init__ calls clang, and __call__ calls the function in the *.so outputted by clang
|
||||
# in CLANG, global_size and local_size are ignored
|
||||
from tinygrad.runtime.ops_clang import ClangProgram
|
||||
from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
|
||||
|
||||
# a concrete example looks like this, this adds two size 1 RawBuffer
|
||||
# first we create two numpy buffers containing 2 and 3
|
||||
@@ -229,7 +229,7 @@ input_a, input_b = RawMallocBuffer.fromCPU(numpy_a), RawMallocBuffer.fromCPU(num
|
||||
output = RawMallocBuffer(1, dtypes.float32)
|
||||
|
||||
# compile the program, run it, and 2+3 does indeed equal 5
|
||||
program = ClangProgram("add", f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
|
||||
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
|
||||
program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
|
||||
print(output.toCPU())
|
||||
assert output.toCPU()[0] == 5, "it's still 5"
|
||||
|
||||
@@ -4,7 +4,7 @@ import struct
|
||||
import json
|
||||
import traceback
|
||||
import numpy as np
|
||||
from tinygrad.runtime.ops_gpu import CLProgram
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from collections import defaultdict
|
||||
import pyopencl as cl
|
||||
@@ -104,21 +104,11 @@ class Thneed:
|
||||
if 'data' in o:
|
||||
self.buffers_to_save.add(buf)
|
||||
|
||||
# load in the programs (this isn't used)
|
||||
prgs = {}
|
||||
for k,v in jdat['programs'].items():
|
||||
print("building", k)
|
||||
try:
|
||||
prgs[k] = CLProgram(k, v, rename=False)
|
||||
except Exception:
|
||||
print("FAILED", k)
|
||||
traceback.print_exc()
|
||||
exit(0)
|
||||
|
||||
# load binaries
|
||||
prgs = {}
|
||||
for o in jdat['binaries']:
|
||||
nptr = ptr + o['length']
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True)
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
|
||||
ptr = nptr
|
||||
|
||||
# populate the cl_cache
|
||||
@@ -208,7 +198,7 @@ class Thneed:
|
||||
# zero out the buffer
|
||||
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
|
||||
|
||||
CLProgram("from_image_strided", """
|
||||
CLProgram("from_image_strided", compile_gpu("""
|
||||
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
int2 l;
|
||||
@@ -216,7 +206,7 @@ class Thneed:
|
||||
l.x = get_global_id(0);
|
||||
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
|
||||
}
|
||||
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
|
||||
"""), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
|
||||
|
||||
# multiple of 32 isn't enough
|
||||
jdat['objects'].append({
|
||||
|
||||
2
test/external/external_test_speed_llama.py
vendored
2
test/external/external_test_speed_llama.py
vendored
@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
class FakeProgram:
|
||||
def __init__(self, name:str, prg:str, binary:bool): pass
|
||||
def __init__(self, name:str, prg:str): pass
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False): pass
|
||||
|
||||
class RawFakeBuffer(RawBuffer):
|
||||
|
||||
@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
|
||||
int idx = get_global_id(0);
|
||||
c[idx] = atan2(a[idx], b[idx]);
|
||||
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
|
||||
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
|
||||
return ret.realized
|
||||
|
||||
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
|
||||
|
||||
@@ -2,36 +2,34 @@
|
||||
import unittest
|
||||
import secrets
|
||||
import string
|
||||
import tempfile
|
||||
import pathlib
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.helpers import cache_compiled
|
||||
import tinygrad.runtime.ops_clang
|
||||
from tinygrad.helpers import diskcache
|
||||
|
||||
def generate_random_string(length=16):
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
compile_call_count = 0
|
||||
|
||||
@diskcache
|
||||
def helper_test_compile(prg:str) -> bytes:
|
||||
global compile_call_count
|
||||
compile_call_count += 1
|
||||
return prg.encode()
|
||||
|
||||
class TestKernelCache(unittest.TestCase):
|
||||
compile_call_count = 0
|
||||
|
||||
@cache_compiled
|
||||
def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs):
|
||||
self.compile_call_count += 1
|
||||
return prg.encode()
|
||||
|
||||
def test_compile_cache(self):
|
||||
prg1 = generate_random_string(64) + "a"
|
||||
prg2 = generate_random_string(64) + "b"
|
||||
cold_compile_res = self.__helper_test_compile(prg1)
|
||||
warm_compile_res = self.__helper_test_compile(prg1)
|
||||
cold_compile_res = helper_test_compile(prg1)
|
||||
warm_compile_res = helper_test_compile(prg1)
|
||||
assert cold_compile_res == warm_compile_res == prg1.encode()
|
||||
assert self.compile_call_count == 1
|
||||
assert compile_call_count == 1
|
||||
|
||||
prg2_res = self.__helper_test_compile(prg2)
|
||||
prg2_res = helper_test_compile(prg2)
|
||||
assert prg2_res == prg2.encode()
|
||||
assert self.compile_call_count == 2
|
||||
assert compile_call_count == 2
|
||||
|
||||
def test_kernel_cache_in_action(self):
|
||||
if Device.DEFAULT not in ["CLANG"]:
|
||||
@@ -42,15 +40,15 @@ class TestKernelCache(unittest.TestCase):
|
||||
x = a + b
|
||||
x.realize()
|
||||
|
||||
orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile
|
||||
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable
|
||||
orig_compile_func = Device['CLANG'].compiler
|
||||
Device['CLANG'].compiler = None # making it not callable
|
||||
|
||||
a1 = Tensor.rand(4,4)
|
||||
b1 = Tensor.rand(4,4)
|
||||
x1 = a1 + b1
|
||||
x1.realize() # Same kernel should be from cache.
|
||||
|
||||
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func
|
||||
Device['CLANG'].compiler = orig_compile_func
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
|
||||
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime)
|
||||
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
|
||||
|
||||
@@ -153,22 +153,12 @@ class GlobalCounters:
|
||||
@staticmethod
|
||||
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
|
||||
|
||||
# *** compiled cache decorator ***
|
||||
|
||||
def cache_compiled(func):
|
||||
if getenv("DISABLE_COMPILER_CACHE"): return func
|
||||
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
|
||||
table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
|
||||
if (ret:=diskcache_get(table, key)): return ret
|
||||
return diskcache_put(table, key, func(self, prg, *args, **kwargs))
|
||||
return wrapper
|
||||
|
||||
# *** universal database cache ***
|
||||
|
||||
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
|
||||
CACHELEVEL = getenv("CACHELEVEL", 2)
|
||||
|
||||
VERSION = 5
|
||||
VERSION = 6
|
||||
_db_connection = None
|
||||
def db_connection():
|
||||
global _db_connection
|
||||
@@ -207,3 +197,11 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return val
|
||||
|
||||
def diskcache(func):
|
||||
def wrapper(*args, **kwargs) -> bytes:
|
||||
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
|
||||
if (ret:=diskcache_get(table, key)): return ret
|
||||
return diskcache_put(table, key, func(*args, **kwargs))
|
||||
setattr(wrapper, "__wrapped__", func)
|
||||
return wrapper
|
||||
|
||||
@@ -194,8 +194,8 @@ class GraphBatchExecutor(BasicBatchExecutor):
|
||||
def exec_instance(self, instid): raise NotImplementedError("must be implemented")
|
||||
|
||||
class ASTRunner:
|
||||
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
|
||||
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
|
||||
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
|
||||
if DEBUG >= 4: print(prg)
|
||||
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
|
||||
|
||||
def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
|
||||
@@ -211,8 +211,9 @@ class ASTRunner:
|
||||
return float('inf')
|
||||
return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
|
||||
|
||||
def build(self, runtime, batch_exec=BasicBatchExecutor):
|
||||
self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
|
||||
def build(self, compiler, runtime, batch_exec=BasicBatchExecutor):
|
||||
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
|
||||
self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec
|
||||
return self
|
||||
|
||||
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
@@ -243,8 +244,8 @@ class ASTRunner:
|
||||
return et
|
||||
|
||||
class Compiled:
|
||||
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
|
||||
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec
|
||||
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
|
||||
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_exec
|
||||
self.method_cache: Dict[LazyOp, ASTRunner] = {}
|
||||
|
||||
def to_program(self, k):
|
||||
@@ -252,7 +253,7 @@ class Compiled:
|
||||
src, runtime_args = self.renderer(k.function_name, k.uops)
|
||||
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
|
||||
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
|
||||
display_name=k.display_name, runtime_args=runtime_args).build(self.runtime, self.batch_exec)
|
||||
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime, self.batch_exec)
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
|
||||
# check if we can reuse the output buffer
|
||||
|
||||
@@ -209,4 +209,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
||||
else:
|
||||
raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {"binary":False}
|
||||
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
|
||||
|
||||
@@ -144,4 +144,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module), {"binary":False}
|
||||
return str(module), {}
|
||||
|
||||
@@ -118,7 +118,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
for x in local_size: acc_local_size *= next_power_of_2(x)
|
||||
local_size = [acc_local_size] + [1] * (len(local_size) - 1)
|
||||
|
||||
if DEBUG >=4: print(prg)
|
||||
if DEBUG >= 4: print(prg)
|
||||
getlines = linecache.getlines
|
||||
linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
|
||||
exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122\
|
||||
@@ -126,4 +126,5 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
|
||||
max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
|
||||
for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
|
||||
return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
|
||||
|
||||
return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
|
||||
from typing import Any
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import cache_compiled
|
||||
from tinygrad.helpers import diskcache
|
||||
from tinygrad.runtime.lib import RawMallocBuffer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
@@ -13,26 +13,24 @@ args = {
|
||||
|
||||
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
|
||||
|
||||
class ClangProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
|
||||
@diskcache
|
||||
def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes:
|
||||
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8'))
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
class ClangProgram:
|
||||
def __init__(self, name:str, prg:bytes):
|
||||
# write to disk so we can load it
|
||||
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
|
||||
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
|
||||
pathlib.Path(cached_file_path.name).write_bytes(prg)
|
||||
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
# TODO: sadly clang doesn't like the use of /dev/stdout here
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
|
||||
if wait: st = time.perf_counter()
|
||||
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
|
||||
if wait: return time.perf_counter()-st
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
|
||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
|
||||
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import Optional, List, Any, Tuple
|
||||
import numpy as np
|
||||
from pycuda.compiler import compile as cuda_compile # type: ignore
|
||||
from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
|
||||
from tinygrad.helpers import DEBUG, getenv, colored, diskcache
|
||||
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -88,9 +88,12 @@ class CUDAGraph(GraphBatchExecutor):
|
||||
|
||||
def exec_instance(self, instid): self.graphs[instid][0].launch()
|
||||
|
||||
@diskcache
|
||||
def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
|
||||
|
||||
class CUDAProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
|
||||
if not binary: prg = self.compile(prg).decode('utf-8')
|
||||
def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
|
||||
prg = _prg.decode('utf-8')
|
||||
if DEBUG >= 5: print(pretty_ptx(prg))
|
||||
if DEBUG >= 6:
|
||||
try:
|
||||
@@ -102,10 +105,6 @@ class CUDAProgram:
|
||||
# TODO: name is wrong, so we get it from the ptx using hacks
|
||||
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
if wait:
|
||||
start, end = cuda.Event(), cuda.Event()
|
||||
@@ -118,7 +117,8 @@ class CUDAProgram:
|
||||
|
||||
if getenv("TRITON") == 1:
|
||||
from tinygrad.renderer.triton import uops_to_triton
|
||||
TritonRenderer = uops_to_triton
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), TritonRenderer, CUDAProgram, cuda.Context.synchronize)
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False),
|
||||
uops_to_triton, lambda x: x.encode('utf-8'), CUDAProgram, cuda.Context.synchronize)
|
||||
else:
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), CUDARenderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
|
||||
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]),
|
||||
CUDARenderer, compile_cuda, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
os.environ['PYOPENCL_NO_CACHE'] = '1'
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.renderer.opencl import OpenCLRenderer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
@@ -61,23 +63,28 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
|
||||
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
|
||||
|
||||
@diskcache
|
||||
def compile_gpu(prg:str) -> bytes:
|
||||
clprg = cl.Program(CL.cl_ctxs[0], prg)
|
||||
clprg.build()
|
||||
return clprg.get_info(cl.program_info.BINARIES)[0]
|
||||
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):
|
||||
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore
|
||||
def __init__(self, name:str, prg:bytes, argdtypes=None, options=None):
|
||||
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs] # type: ignore
|
||||
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
|
||||
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
|
||||
if DEBUG >= 5 and not OSX:
|
||||
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
|
||||
fromimport('disassemblers.adreno', 'disasm')(self.binary())
|
||||
fromimport('disassemblers.adreno', 'disasm')(prg)
|
||||
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg))
|
||||
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
|
||||
else:
|
||||
# print the PTX for NVIDIA. TODO: probably broken for everything else
|
||||
print(self.binary().decode('utf-8'))
|
||||
print(prg.decode('utf-8'))
|
||||
if argdtypes is not None: self.set_argdtypes(argdtypes)
|
||||
|
||||
def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0]
|
||||
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
|
||||
|
||||
@staticmethod
|
||||
@@ -100,4 +107,4 @@ class CLProgram:
|
||||
return None
|
||||
return None
|
||||
|
||||
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, CLProgram, CL.synchronize)
|
||||
GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize)
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import ctypes, functools
|
||||
import extra.hip_wrapper as hip
|
||||
from typing import Tuple, Any, List
|
||||
from tinygrad.helpers import DEBUG, getenv, cache_compiled
|
||||
from tinygrad.helpers import DEBUG, getenv, diskcache
|
||||
from tinygrad.ops import Compiled, ASTRunner, GraphBatchExecutor
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
@@ -78,10 +78,15 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
hip.hipSetDevice(x._device)
|
||||
hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)
|
||||
|
||||
@diskcache
|
||||
def compile_hip(prg) -> bytes:
|
||||
prog = hip.hiprtcCreateProgram(prg, "<null>", [], [])
|
||||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
|
||||
return hip.hiprtcGetCode(prog)
|
||||
|
||||
class HIPProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
def __init__(self, name:str, prg:bytes):
|
||||
self.modules, self.prgs = [], []
|
||||
prg = prg if binary else self.compile(prg, name)
|
||||
|
||||
if DEBUG >= 6:
|
||||
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
|
||||
@@ -92,12 +97,6 @@ class HIPProgram:
|
||||
self.modules.append(hip.hipModuleLoadData(prg))
|
||||
self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name))
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg, name) -> bytes:
|
||||
prog = hip.hiprtcCreateProgram(prg, name, [], [])
|
||||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
|
||||
return hip.hiprtcGetCode(prog)
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
hip.hipSetDevice(args[0]._device)
|
||||
if wait:
|
||||
@@ -138,4 +137,4 @@ __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset
|
||||
""",
|
||||
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]))
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, HIPProgram, hip.hipDeviceSynchronize, HIPGraph)
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, HIPGraph)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time, ctypes
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import getenv, DEBUG, cache_compiled
|
||||
from tinygrad.helpers import getenv, DEBUG, diskcache
|
||||
from ctypes import CFUNCTYPE
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.llvmir import uops_to_llvm_ir
|
||||
@@ -9,6 +9,8 @@ from tinygrad.runtime.lib import RawMallocBuffer
|
||||
|
||||
import llvmlite.binding as llvm # type: ignore
|
||||
|
||||
LLVMOPT = bool(getenv("LLVMOPT"))
|
||||
|
||||
class LLVM:
|
||||
target_machine: ClassVar[llvm.targets.TargetMachine] = None
|
||||
engine: ClassVar[llvm.executionengine.ExecutionEngine] = None
|
||||
@@ -26,7 +28,7 @@ class LLVM:
|
||||
LLVM.target_machine.add_analysis_passes(LLVM.optimizer)
|
||||
|
||||
# TODO: this makes compile times so much faster
|
||||
if getenv("LLVMOPT"):
|
||||
if LLVMOPT:
|
||||
llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
|
||||
if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize')
|
||||
#llvm.set_option(str(), '--debug')
|
||||
@@ -44,19 +46,18 @@ class LLVM:
|
||||
backing_mod.triple = llvm.get_process_triple()
|
||||
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
|
||||
|
||||
class LLVMProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
self.prg = prg if binary else self.compile(prg)
|
||||
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg))
|
||||
self.fxn = LLVM.engine.get_function_address(name)
|
||||
@diskcache
|
||||
def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes:
|
||||
mod = llvm.parse_assembly(prg)
|
||||
mod.verify()
|
||||
LLVM().optimizer.run(mod)
|
||||
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
|
||||
return LLVM.target_machine.emit_object(mod)
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
mod = llvm.parse_assembly(prg)
|
||||
mod.verify()
|
||||
LLVM().optimizer.run(mod)
|
||||
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
|
||||
return LLVM.target_machine.emit_object(mod)
|
||||
class LLVMProgram:
|
||||
def __init__(self, name:str, lib:bytes):
|
||||
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
|
||||
self.fxn = LLVM.engine.get_function_address(name)
|
||||
|
||||
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
|
||||
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
|
||||
@@ -64,4 +65,4 @@ class LLVMProgram:
|
||||
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
|
||||
if wait: return time.perf_counter()-st
|
||||
|
||||
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram)
|
||||
LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any, Tuple
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache
|
||||
from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor
|
||||
from tinygrad.renderer.metal import MetalRenderer
|
||||
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
|
||||
@@ -58,9 +58,21 @@ def unwrap(x):
|
||||
assert err is None, str(err)
|
||||
return ret
|
||||
|
||||
@diskcache
|
||||
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
|
||||
if use_xcode:
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
# TODO: avoid file write here?
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
class MetalProgram:
|
||||
def __init__(self, name:str, prg:str, binary:bool=False):
|
||||
lib = prg if binary else self.compile(prg)
|
||||
def __init__(self, name:str, lib:bytes):
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
|
||||
self.fxn = self.library.newFunctionWithName_(name)
|
||||
@@ -71,19 +83,6 @@ class MetalProgram:
|
||||
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
|
||||
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
|
||||
|
||||
@cache_compiled
|
||||
def compile(self, prg) -> bytes:
|
||||
if getenv("METAL_XCODE"):
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
# TODO: avoid file write here?
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None)
|
||||
return pathlib.Path(output_file.name).read_bytes()
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False):
|
||||
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
@@ -101,4 +100,4 @@ class MetalProgram:
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, MetalProgram, METAL.synchronize, MetalBatchExecutor)
|
||||
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, MetalBatchExecutor)
|
||||
|
||||
@@ -12,7 +12,7 @@ import wgpu # type: ignore
|
||||
wgpu_device = get_default_device()
|
||||
|
||||
class WebGPUProgram:
|
||||
def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
|
||||
def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False):
|
||||
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
|
||||
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
|
||||
@@ -42,4 +42,4 @@ class RawWebGPUBuffer(RawBufferCopyIn):
|
||||
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
|
||||
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram)
|
||||
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)
|
||||
|
||||
Reference in New Issue
Block a user