Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
Refactor ASTs (#622)

* ugh worst branch name
* compiler refactor continues
* scc -> cloc
* buf -> _buf
* finish _buf, and program -> runtime
* gpu is still working, clang isn't
* clang in new style
* ops_metal
* something broke it
* improve metal
* clean up tons of cl crap
* hack fix sync
* cleaner gpu
* gpu metal clang
* cleanups
* minor refactor
* GPUCodegen
* fix up LLVM
* blind CUDA refactor
* codegen / runtime
* keep ops naming
* linter passes
* woah, llvm was allocing 4x what it needed to
* bugfixes
* fix openpilot compiler
* fix compile_efficientnet
* method cache should fix tests
* deal with duped functions
.github/workflows/test.yml (vendored): 2 changed lines
@@ -72,7 +72,7 @@ jobs:
- name: Install Dependencies
run: pip install -e .
- name: Compile EfficientNet to C
run: CLANG=1 GPU=1 python3 examples/compile_efficientnet.py > recognize.c
run: CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
- name: Compile C to native
run: clang -O2 recognize.c -lm -o recognize
- name: Test EfficientNet

@@ -14,7 +14,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExe
from tinygrad.shape import ShapeTracker
from tinygrad.helpers import prod, DEBUG
from tinygrad.runtime.cuda import CLBuffer
from tinygrad.ast import ASTKernel
from tinygrad.compiler.ast import ASTKernel

stream = cuda.Stream()

@@ -4,31 +4,29 @@ from extra.utils import fetch
import ast

def compile_net(run, special_names):
# c header
cprog = ["#include <stdio.h>", "#include <math.h>", "#define max(x,y) ((x>y)?x:y)"]

# functions that run the net
functions = {}
bufs = {}
bufnum = 0
statements = []
bufs_to_save = {}
for fxn,args in run.jit_cache:
cprog.append(fxn.clprg.prg)
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):
if i in fxn.bufs_to_delete: continue
key = id(arg.cl)
key = id(arg.raw())
if key not in bufs:
if key in special_names:
bufs[key] = (special_names[key], len(arg.cl)//4)
bufs[key] = (special_names[key], len(arg.raw()._buf))
else:
bufs[key] = (f"buf_{bufnum}", len(arg.cl)//4)
bufs[key] = (f"buf_{bufnum}", len(arg.raw()._buf))
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg.cl # if first usage of a buffer is not an output, and it's not a special name
if i > 0: bufs_to_save[bufs[key][0]] = arg.raw() # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
statements.append(f"{fxn.clprg.name}({', '.join(cargs)});")
statements.append(f"{fxn.name}({', '.join(cargs)});")

return cprog, statements, bufs, bufs_to_save
return functions, statements, bufs, bufs_to_save

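For reference, a minimal self-contained sketch of the dedup-by-name bookkeeping the new compile_net relies on (the NOTE above assumes kernels with the same name have identical source). FakeKernel and the sample jit_cache are invented stand-ins for the real runner objects:

from collections import namedtuple

FakeKernel = namedtuple("FakeKernel", ["name", "prg"])          # invented stand-in
jit_cache = [(FakeKernel("ew_S64", "void ew_S64(float*, float*) {}"), ["buf_0", "buf_1"]),
             (FakeKernel("ew_S64", "void ew_S64(float*, float*) {}"), ["buf_1", "buf_2"]),
             (FakeKernel("re_S16", "void re_S16(float*, float*) {}"), ["buf_2", "buf_3"])]

functions, statements = {}, []
for fxn, args in jit_cache:
  functions[fxn.name] = fxn.prg                                 # duplicates collapse to one definition
  statements.append(f"{fxn.name}({', '.join(args)});")

assert len(functions) == 2 and len(statements) == 3             # 3 calls, 2 unique kernels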
if __name__ == "__main__":
model = EfficientNet(0)
@@ -44,21 +42,17 @@ if __name__ == "__main__":
the_output = run(the_input)

# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
special_names = {id(the_input.lazydata.realized.raw()): "input", id(the_output.lazydata.realized.raw()): "outputs"}

cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)

# buffers (empty)
cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] not in bufs_to_save]
# c header
cprog = ["#include <stdio.h>", "#include <math.h>", "#define max(x,y) ((x>y)?x:y)"]

# buffers (weights)
# save the weights
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(memoryview(cl)[0:len(cl)//4])])
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog.append(f"float *{name} = (float *){name}_data;")

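For reference, a small standalone sketch of the weight-embedding trick used above: the raw bytes of a float buffer become an escaped C string literal, and the generated C reinterprets that char array as float*. numpy stands in for the cl._buf object here:

import numpy as np

weights = np.array([1.0, 0.5, -2.0], dtype=np.float32)        # stand-in for cl._buf
escaped = ''.join("\\x%02X" % b for b in weights.tobytes())
cprog = [f'unsigned char w_data[] = "{escaped}";',
         "float *w = (float *)w_data;"]                        # same cast the generated C performs
print('\n'.join(cprog))

# sanity check: undoing the \xNN escaping recovers the original floats
raw = escaped.encode().decode('unicode_escape').encode('latin-1')
assert (np.frombuffer(raw, dtype=np.float32) == weights).all()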
# the net
cprog += ["void net() {"] + statements + ["}"]

# image library!
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8')]
@@ -69,6 +63,15 @@ if __name__ == "__main__":
lbls = ['"'+lbls[i]+'"' for i in range(1000)]
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")

# buffers (empty + weights)
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]

# the functions
cprog += list(functions.values())

# the net
cprog += ["void net() {"] + statements + ["}"]

cprog += ["""
int main(int argc, char* argv[]) {
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
@@ -101,5 +104,5 @@ int main(int argc, char* argv[]) {
else printf("%s\\n", lbls[best_idx]);
}"""]

# CLANG=1 GPU=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && time ./recognize docs/stable_diffusion_by_tinygrad.jpg
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/stable_diffusion_by_tinygrad.jpg
print('\n'.join(cprog))

@@ -4,7 +4,7 @@ import time
import sys
np.set_printoptions(linewidth=160)
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
from tinygrad.llops.ops_llvm import LLVM, LLVMBuffer, int_const
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
from llvmlite import ir # type: ignore

@@ -1,5 +1,5 @@
import numpy as np
from tinygrad.runtime.metal import CLBuffer, CLProgram
from tinygrad.runtime.ops_metal import CLBuffer, CLProgram

def benchmark(prog):
e = prog()

@@ -3,7 +3,7 @@ import gc
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.llops.ops_gpu import GPUBuffer
from tinygrad.runtime.ops_gpu import GPUBuffer
from tinygrad.ops import GlobalCounters

def print_objects():

@@ -6,7 +6,7 @@ from enum import Enum
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from tinygrad.helpers import getenv, DEBUG
from extra.lib_test_ast import test_ast

@@ -1,8 +1,8 @@
import sys
import numpy as np
from typing import Dict, Type
from tinygrad.ast import ASTKernel
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.compiler.ast import ASTKernel
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import DeviceBuffer, map_buffers

in_test = False

@@ -4,12 +4,11 @@ import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.opencl import CL
from tinygrad.llops.ops_gpu import CLProgram, CLImage, CLBuffer
from tinygrad.runtime.ops_gpu import CLProgram, CLImage, CLBuffer
from tinygrad.helpers import prod, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO

DEBUGCL = getenv("DEBUGCL", 0)
FLOAT16 = getenv("FLOAT16", 0)
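The CL() -> CL changes throughout this file rely on the context and queue being class-level attributes of the new ops_gpu CL helper, so callers no longer construct a throwaway instance just to reach them. A rough sketch of that pattern, with invented placeholder values instead of real pyopencl objects:

class CL:
  cl_ctx, cl_queue = None, None                  # class-level, shared by all callers
  @classmethod
  def post_init(cls):
    if cls.cl_ctx is None:
      cls.cl_ctx, cls.cl_queue = "ctx", "queue"  # stand-ins for cl.Context()/cl.CommandQueue()

CL.post_init()
print(CL.cl_ctx, CL.cl_queue)                    # accessed straight off the class, as in the hunks below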
@@ -75,29 +74,29 @@ class Thneed:
|
||||
if o['arg_type'] == "image2d_t":
|
||||
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
|
||||
# hack: use a image1d since we can back that with a buffer
|
||||
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
# buffer isn't supported in image2d, copy buffer into image
|
||||
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
|
||||
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
|
||||
cl.enqueue_copy(q, arr, bufs[o['buffer_id']])
|
||||
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
|
||||
elif o['needs_load']:
|
||||
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
|
||||
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
|
||||
else:
|
||||
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
|
||||
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
|
||||
if o['arg_type'] == "image1d_t":
|
||||
assert not o['needs_load']
|
||||
assert not bufs_loaded[o['buffer_id']]
|
||||
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
|
||||
else:
|
||||
if 'data' in o:
|
||||
buf = cl.Buffer(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
|
||||
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
|
||||
else:
|
||||
# zero out buffers
|
||||
buf = cl.Buffer(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
|
||||
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
|
||||
|
||||
bufs[o['id']] = buf
|
||||
bufs_loaded[o['id']] = 'data' in o
|
||||
@@ -119,7 +118,7 @@ class Thneed:
|
||||
# load binaries
|
||||
for o in jdat['binaries']:
|
||||
nptr = ptr + o['length']
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], rename=False, binary=True)
|
||||
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True)
|
||||
ptr = nptr
|
||||
|
||||
# populate the cl_cache
|
||||
@@ -194,17 +193,17 @@ class Thneed:
|
||||
})
|
||||
if needs_load:
|
||||
data = np.empty(a.size//4, dtype=np.float32)
|
||||
cl.enqueue_copy(CL().cl_queue, data, a, is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, data, a, is_blocking=True)
|
||||
weights.append(data.tobytes())
|
||||
elif isinstance(a, cl.Image):
|
||||
needs_load = a in self.buffers_to_save
|
||||
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
|
||||
size = row_pitch * a.shape[1]
|
||||
# this is *2 if float16 and *4 if float32
|
||||
buf = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
|
||||
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
|
||||
|
||||
# zero out the buffer
|
||||
cl.enqueue_copy(CL().cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
|
||||
|
||||
CLProgram("from_image_strided", """
|
||||
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
|
||||
@@ -224,7 +223,7 @@ class Thneed:
|
||||
|
||||
if needs_load:
|
||||
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
|
||||
cl.enqueue_copy(CL().cl_queue, data, buf, is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, data, buf, is_blocking=True)
|
||||
if FLOAT16: data = data.astype(np.float16)
|
||||
weights.append(data.tobytes())
|
||||
else:
|
||||
@@ -271,9 +270,9 @@ class Thneed:
|
||||
events = []
|
||||
st = time.monotonic()
|
||||
for prg, args in self.cl_cache:
|
||||
events.append(prg.clprg(CL().cl_queue, *args))
|
||||
events.append(prg.clprg(CL.cl_queue, *args))
|
||||
mt = time.monotonic()
|
||||
CL().cl_queue.finish()
|
||||
CL.cl_queue.finish()
|
||||
et = time.monotonic() - st
|
||||
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
|
||||
|
||||
@@ -284,7 +283,7 @@ class Thneed:
|
||||
total_runtime = 0
|
||||
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
|
||||
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
|
||||
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(prg.op_estimate)/runtime:9.2f} GFLOPS {prg.options} -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
|
||||
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
|
||||
if (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3:
|
||||
print(prg.prg)
|
||||
total_runtime += runtime
|
||||
@@ -321,11 +320,11 @@ class Thneed:
|
||||
# 3 runs just in case
|
||||
for i in range(3):
|
||||
try:
|
||||
e = prg.clprg(CL().cl_queue, *args)
|
||||
e = prg.clprg(CL.cl_queue, *args)
|
||||
except (cl.LogicError, cl.RuntimeError):
|
||||
# INVALID_WORK_GROUP_SIZE
|
||||
continue
|
||||
CL().cl_queue.finish()
|
||||
CL.cl_queue.finish()
|
||||
runtime = e.profile.end - e.profile.start
|
||||
#print(runtime, args[0], args[1])
|
||||
runtimes.append((runtime, local_args))
|
||||
|
||||
@@ -19,7 +19,8 @@ import numpy as np
import tinygrad.graph as graph
from tinygrad.ops import GlobalCounters

from tinygrad.runtime.opencl import CL
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL
from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor
@@ -64,13 +65,16 @@ def compile(dat, output_fn):
cl_cache = []
for prg,args in model_exec.jit_cache:
real_clprg = prg.clprg
setattr(real_clprg, "op_estimate", prg.op_estimate)
used_ops += real_clprg.op_estimate
# replace clprg with a fake program to log to cl_cache
prg.clprg = lambda *args: cl_cache.append((real_clprg, args))
prg.clprg = lambda *args, wait=False: cl_cache.append((real_clprg, list(args[0:2])+[x._cl for x in args[2:]]))
prg(*args)
# put it back
prg.clprg = real_clprg

from extra.thneed import Thneed
t = Thneed(cl_cache, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
t = Thneed(cl_cache, {k:inputs[k].lazydata.realized.raw()._cl for k in inputs.keys()})
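For reference, a toy version of the capture trick above: each program's clprg is swapped for a recording lambda, the cached call is replayed, then the real program is restored. The Prog class here is an invented stand-in for the real jit_cache entries:

class Prog:                                   # invented stand-in for a jit_cache entry
  def __init__(self, name):
    self.name = name
    self.clprg = lambda *a, wait=False: None  # the "real" compiled program
  def __call__(self, *bufs):
    return self.clprg([8], None, *bufs, wait=False)

cl_cache = []
p = Prog("ew_S64")
real_clprg = p.clprg
p.clprg = lambda *args, wait=False: cl_cache.append((real_clprg, args))  # record instead of run
p("bufA", "bufB")                             # replaying the cached call logs it
p.clprg = real_clprg                          # put it back
assert cl_cache[0][1][2:] == ("bufA", "bufB")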
|
||||
if getenv("OPTWG", 0):
|
||||
t.optimize_local_workgroup()
|
||||
@@ -84,7 +88,7 @@ def compile(dat, output_fn):
|
||||
|
||||
# confirm thneed found the right output
|
||||
thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
|
||||
CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, thneed_out, t.outputs[0], is_blocking=True)
|
||||
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
|
||||
|
||||
# testing is float32 only (fix this)
|
||||
@@ -102,11 +106,11 @@ def compile(dat, output_fn):
|
||||
|
||||
# try old thneed with a different input
|
||||
for k,v in t.inputs.items():
|
||||
CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
|
||||
|
||||
t.run()
|
||||
old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
|
||||
CL.enqueue_copy(old_thneed_out, t.outputs[0], is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, old_thneed_out, t.outputs[0], is_blocking=True)
|
||||
|
||||
# compare thneed (rerun) with torch
|
||||
np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
|
||||
@@ -119,11 +123,11 @@ def compile(dat, output_fn):
|
||||
|
||||
# inputs
|
||||
for k,v in nt.inputs.items():
|
||||
CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
|
||||
|
||||
nt.run()
|
||||
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
|
||||
CL.enqueue_copy(new_thneed_out, nt.outputs[0], is_blocking=True)
|
||||
cl.enqueue_copy(CL.cl_queue, new_thneed_out, nt.outputs[0], is_blocking=True)
|
||||
|
||||
# compare torch to thneed
|
||||
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
|
||||
|
||||
setup.py: 2 changed lines
@@ -14,7 +14,7 @@ setup(name='tinygrad',
license='MIT',
long_description=long_description,
long_description_content_type='text/markdown',
packages = ['tinygrad', 'tinygrad.llops', 'tinygrad.nn', 'tinygrad.runtime', 'tinygrad.shape'],
packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.runtime', 'tinygrad.shape'],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"

sz.sh: 3 changed lines
@@ -1,2 +1,3 @@
#!/bin/bash
scc tinygrad --by-file
# switched to cloc due to https://github.com/boyter/scc/issues/379
cloc --by-file tinygrad/*

@@ -3,7 +3,7 @@ import unittest
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.helpers import getenv
from extra.lib_test_ast import test_ast

@@ -7,7 +7,7 @@ if 'IMAGE' not in os.environ:
os.environ['GPU'] = '1'
os.environ['OPT'] = '2'
from tinygrad.tensor import Tensor
from tinygrad.llops.ops_gpu import CLImage
from tinygrad.runtime.ops_gpu import CLImage
from tinygrad.nn import Conv2d
Tensor.no_grad = True

@@ -14,8 +14,8 @@ class TestConvShapetracker(unittest.TestCase):
conv(inp).realize()
test = GlobalCounters.cache
GlobalCounters.cache = None
assert len(test) == 1, f"conv should only have one kernel {[x[0].clprg.name for x in test]}"
print(test[0][0].clprg.prg)
assert len(test) == 1, f"conv should only have one kernel {[x[0].name for x in test]}"
print(test[0][0].prg)
for arg in test[0][1]:
print(arg.st)
assert len(arg.st.views) == 1

@@ -3,7 +3,7 @@ import unittest
import networkx as nx # type: ignore
import numpy as np
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps

class TestGraph(unittest.TestCase):

@@ -12,15 +12,6 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.helpers import colored, getenv, DEBUG
|
||||
from tinygrad.jit import TinyJit
|
||||
METAL = getenv("METAL")
|
||||
try:
|
||||
from tinygrad.runtime.opencl import CL
|
||||
if METAL:
|
||||
from tinygrad.runtime.metal import sync
|
||||
else:
|
||||
def sync(): CL().cl_queue.finish()
|
||||
except ImportError:
|
||||
def sync(): pass
|
||||
|
||||
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
|
||||
|
||||
@@ -43,22 +34,17 @@ def helper_test_speed(f1, *args):
|
||||
ret = None
|
||||
for _ in range(CNT):
|
||||
del ret
|
||||
GlobalCounters.global_ops = 0
|
||||
GlobalCounters.global_mem = 0
|
||||
args = [(x+1).realize() if isinstance(x,Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
|
||||
|
||||
# sync all before!
|
||||
sync()
|
||||
torch.zeros(1, device=torch_device).cpu()
|
||||
# force syncing
|
||||
[x.cpu().numpy() for x in args if x is not None]
|
||||
|
||||
GlobalCounters.global_ops = 0
|
||||
GlobalCounters.global_mem = 0
|
||||
if DEBUG >= 4: print("benchmark start")
|
||||
st = time.monotonic()
|
||||
ret = f1(*args)
|
||||
if isinstance(ret, Tensor) and ret.device in ["GPU"]:
|
||||
sync()
|
||||
if not isinstance(ret, Tensor) and torch_device != "cpu":
|
||||
# TODO: better way to sync?
|
||||
torch.zeros(1, device=torch_device).cpu()
|
||||
ret.cpu().numpy() # not ideal, it's copying. why is this so slow in tinygrad?
|
||||
et = (time.monotonic() - st) * 1000
|
||||
ets.append(et)
|
||||
if DEBUG >= 4: print("benchmark stop")
|
||||
|
||||
@@ -1,10 +1,26 @@
from enum import Enum, auto
import itertools
from typing import List, Tuple, Optional
from tinygrad.helpers import prod, dedup, all_same
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
from typing import List, Tuple, Optional, Set
from tinygrad.helpers import prod, dedup, all_same, DEBUG
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, GlobalCounters
from tinygrad.shape import ShapeTracker, View, strides_for_shape

class ASTRunner:
def __init__(self, name, prg, bufs_to_delete:Set[int]=set(), global_work_size:Optional[List[int]]=None, local_work_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_work_size, self.local_work_size, self.bufs_to_delete, self.op_estimate, self.mem_estimate = name, prg, global_work_size, local_work_size, bufs_to_delete, op_estimate, mem_estimate
def build(self, runtime):
self.clprg = runtime(self.name, self.prg)
return self
def __call__(self, *bufs):
et = self.clprg(self.global_work_size, self.local_work_size, *[x.raw() for i,x in enumerate(bufs) if i not in self.bufs_to_delete], wait=DEBUG>=2)
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 1:
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(bufs)-len(self.bufs_to_delete):5d} kernels {str(self.global_work_size):18s} {str(self.local_work_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if DEBUG <= 1 else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return et

def get_first_reduce(shapes):
for i in range(len(shapes[0])):
if not all_same([x[i] for x in shapes]):
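A hedged usage sketch for the new ASTRunner, assuming the class above is importable; DummyProgram is an invented runtime whose call signature mirrors the CLProgram-style runtimes (global size, local size, buffers, wait):

class DummyProgram:                             # invented runtime for illustration
  def __init__(self, name, prg): self.name, self.prg = name, prg
  def __call__(self, global_size, local_size, *bufs, wait=False):
    print(f"launch {self.name} global={global_size} local={local_size} bufs={len(bufs)}")
    return 1e-6 if wait else None               # pretend the timed run took 1us

runner = ASTRunner("ew_S64", "/* kernel source */", global_work_size=[64]).build(DummyProgram)
# runner(*bufs) would now launch via DummyProgram; real buffers must expose .raw()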
@@ -1,42 +1,35 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
import math
|
||||
from typing import List, Tuple, Optional, Dict, Union, Set, Final, Callable
|
||||
from tinygrad.helpers import prod, DEBUG, IMAGE
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
|
||||
from tinygrad.ast import ASTKernel, Token, Types
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op
|
||||
from tinygrad.codegen.ast import ASTKernel, ASTRunner, Token, Types
|
||||
from tinygrad.shape.symbolic import Node, ModNode, DivNode, render_python
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.helpers import getenv, DEBUG, prod
|
||||
|
||||
# div is different in cl than python
|
||||
render_cl = render_python.copy()
|
||||
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
# TODO: select runtimes in a smarter way
|
||||
CUDA,METAL,CLANG = getenv("CUDA", 0), getenv("METAL", 0), getenv("CLANG", 0)
|
||||
if not CUDA and not METAL and not CLANG:
|
||||
from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
|
||||
else:
|
||||
class CLImage: # type: ignore
|
||||
def __init__(self, shape): raise NotImplementedError("current runtime doesn't support images")
|
||||
if CUDA: from tinygrad.runtime.cuda import CLBuffer, CLProgram # type: ignore
|
||||
elif METAL: from tinygrad.runtime.metal import CLBuffer, CLProgram # type: ignore
|
||||
elif CLANG: from tinygrad.runtime.clang import CLBuffer, CLProgram # type: ignore
|
||||
|
||||
VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
|
||||
NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
|
||||
|
||||
KOPT = getenv("KOPT", -1)
|
||||
PRINT_AST = getenv("PRINT_AST", "0")
|
||||
TEST_AST = getenv("TEST_AST", 0)
|
||||
class GPULanguage(NamedTuple):
kernel_prefix : str = ""
buffer_prefix : str = ""
smem_prefix : str = ""
barrier : str = ""
gid : List[str] = []
lid : List[str] = []
extra_args : List[str] = []
float4 : Optional[str] = None

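The field values below are illustrative guesses at what an OpenCL-flavored GPULanguage might contain, not the exact strings tinygrad ships; they only show how the NamedTuple is meant to be filled per backend:

opencl_lang = GPULanguage(
  kernel_prefix="__kernel", buffer_prefix="__global ", smem_prefix="__local ",
  barrier="barrier(CLK_LOCAL_MEM_FENCE);",
  gid=[f"get_global_id({i})" for i in range(3)],
  lid=[f"get_local_id({i})" for i in range(3)],
  float4="(float4)")

# a CPU/CLANG-style backend could use the all-default GPULanguage(): an empty gid
# list makes the codegen emit plain for-loops, and float4=None disables upcasting
clang_lang = GPULanguage()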
class GPURunner:
|
||||
def __init__(self, clprg:CLProgram, bufs_to_delete:Set[int], global_work_size:List[int], local_work_size:Optional[List[int]]):
|
||||
self.clprg, self.global_work_size, self.local_work_size, self.bufs_to_delete = clprg, global_work_size, local_work_size, bufs_to_delete
|
||||
def __call__(self, *bufs):
|
||||
return self.clprg(self.global_work_size, self.local_work_size, *[x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete])
|
||||
class GPUCodegen(ASTKernel):
|
||||
lang : GPULanguage = GPULanguage()
|
||||
|
||||
# for renaming
|
||||
kernel_cnt : Final[Dict[str, int]] = defaultdict(lambda: -1)
|
||||
|
||||
class CLASTKernel(ASTKernel):
|
||||
code_for_op : Final[Dict[Op, str]] = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "((float)1.0-A)",
|
||||
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
|
||||
@@ -58,7 +51,7 @@ class CLASTKernel(ASTKernel):
|
||||
assert len(self.bufs[buf_index].st.views) == 1, "store has more than one view"
|
||||
|
||||
# all stores can merge, since they have one view and are valid
|
||||
should_upcast = not CLANG and self.buftokens[buf_index].can_float4()
|
||||
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4()
|
||||
|
||||
to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
|
||||
did_store = set()
|
||||
@@ -68,14 +61,14 @@ class CLASTKernel(ASTKernel):
|
||||
assert valid.min == 1, "store must always be valid"
|
||||
if should_upcast:
|
||||
for j in range(4): did_store.add(o+j)
|
||||
v = Token(f"{CLProgram.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
|
||||
v = Token(f"{self.lang.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
|
||||
idxy, valid = self.sts[buf_index].expr_idxs(o)
|
||||
assert valid.min == 1, "store must always be valid"
|
||||
if isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
if hasattr(self.bufs[buf_index]._buf, "IMAGE"):
|
||||
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
|
||||
self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
|
||||
elif v.typ == Types.FLOAT4:
|
||||
self.kernel.append(f"(({CLProgram.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
|
||||
self.kernel.append(f"(({self.lang.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
|
||||
else:
|
||||
self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).render(render_cl)}] = {v.tok};\n")
|
||||
|
||||
@@ -87,7 +80,7 @@ class CLASTKernel(ASTKernel):
|
||||
val = self.bufs[buf_index]._backing[0]
|
||||
assert not math.isnan(val)
|
||||
const = Token(f"({val}f)", Types.FLOAT)
|
||||
should_upcast = not CLANG and const is None and self.buftokens[buf_index].can_float4()
|
||||
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4()
|
||||
tokens = []
|
||||
for o in self.buftokens[buf_index].offsets():
|
||||
key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
|
||||
@@ -102,14 +95,14 @@ class CLASTKernel(ASTKernel):
|
||||
#print((idxy+j).render(), idxy_test.render(), valid.render(), valid_test.render(), can_merge)
|
||||
if const is not None:
|
||||
ldr = const
|
||||
elif isinstance(self.bufs[buf_index]._buf, CLImage):
|
||||
elif hasattr(self.bufs[buf_index]._buf, "IMAGE"):
|
||||
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
|
||||
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
|
||||
elif should_upcast and can_merge:
|
||||
ldr = Token(f"(({CLProgram.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
|
||||
ldr = Token(f"(({self.lang.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
|
||||
else:
|
||||
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
|
||||
ldr = ldr if valid.min == 1 or (VALIDHACKS and isinstance(self.bufs[buf_index]._buf, CLImage)) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : 0.0f)", ldr.typ) if valid.max == 1 else Token("0.0f", ldr.typ))
|
||||
ldr = ldr if valid.min == 1 or (VALIDHACKS and hasattr(self.bufs[buf_index]._buf, "IMAGE")) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : 0.0f)", ldr.typ) if valid.max == 1 else Token("0.0f", ldr.typ))
|
||||
if const is not None:
|
||||
self.loaded_keys[(buf_index,o)] = ldr
|
||||
else:
|
||||
@@ -122,11 +115,11 @@ class CLASTKernel(ASTKernel):
|
||||
tokens.append(self.loaded_keys[(buf_index,o)])
|
||||
return tokens
|
||||
|
||||
def ast_parse(self, x:Union[GPUBuffer, LazyOp], acc:List[Token], do_reduce=False) -> List[Token]:
|
||||
def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]:
|
||||
if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
|
||||
if isinstance(x.op, ReduceOps) and not do_reduce: return acc
|
||||
values : List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
|
||||
code = CLASTKernel.code_for_op[x.op] # TODO: replace this with a function
|
||||
code = GPUCodegen.code_for_op[x.op] # TODO: replace this with a function
|
||||
if len(values) == 2:
|
||||
assert len(values[0]) == len(values[1]) and values[0][0].typ == values[1][0].typ, f"values mismatch {values}"
|
||||
return [Token(code.replace("A", a.tok).replace("B", b.tok), a.typ) for a,b in zip(values[0], values[1])]
|
||||
@@ -136,10 +129,10 @@ class CLASTKernel(ASTKernel):
|
||||
def hand_coded_optimizations(self):
|
||||
# if there's images in the earlybufs, we have to make an axis the 4 loading one
|
||||
# shove the axis to the end and remove
|
||||
if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
|
||||
if any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs):
|
||||
eb_valids = [True] * self.shape_len
|
||||
for i in range(len(self.bufs)):
|
||||
if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
|
||||
if hasattr(self.bufs[i]._buf, "IMAGE") and self.bufs[i] in self.earlybufs:
|
||||
valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
|
||||
eb_valids = [x and y for x,y in zip(eb_valids, valids)]
|
||||
assert any(eb_valids), f"invalid op with images {eb_valids}"
|
||||
@@ -158,18 +151,18 @@ class CLASTKernel(ASTKernel):
|
||||
self.simplify_ones()
|
||||
|
||||
# are we grouping?
|
||||
if not CLANG and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
|
||||
if self.lang.float4 and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
|
||||
# TODO: use 1024 if it's allowed in a smarter way
|
||||
for sz in ((([1024] if METAL else []) + [256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
|
||||
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
|
||||
if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
|
||||
self.group_for_reduce.append(sz)
|
||||
break
|
||||
|
||||
# if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
|
||||
if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and not self.buftokens[0].can_float4():
|
||||
if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf not in self.earlybufs) and not self.buftokens[0].can_float4():
|
||||
lb_valids = [True] * self.shape_len
|
||||
for i in range(len(self.bufs)):
|
||||
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
|
||||
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not hasattr(self.bufs[i]._buf, "IMAGE") or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
|
||||
lb_valids = [x and y for x,y in zip(lb_valids, valids)]
|
||||
assert any(lb_valids), f"invalid op with images {lb_valids}"
|
||||
lb_valid = lb_valids.index(True)
|
||||
@@ -192,7 +185,7 @@ class CLASTKernel(ASTKernel):
|
||||
self.simplify_ones()
|
||||
|
||||
# split to 4 float4s
|
||||
if self.buftokens[0].can_float4() and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
|
||||
if self.buftokens[0].can_float4() and any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
|
||||
xb_choices = []
|
||||
for i in range(self.first_reduce):
|
||||
if all(st.shape[i]%4 == 0 for st in self.sts):
|
||||
@@ -214,7 +207,7 @@ class CLASTKernel(ASTKernel):
|
||||
self.simplify_ones()
|
||||
|
||||
# use more opencl indexing
|
||||
if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
|
||||
if self.first_reduce == 2 and hasattr(self.bufs[0]._buf, "IMAGE"):
|
||||
base_shape = self.bufs[0]._base_shape
|
||||
if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
|
||||
if DEBUG >= 3: print("split opencl", base_shape, self.sts[0].shape)
|
||||
@@ -237,11 +230,11 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
# STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
|
||||
# group_for_reduce will have to be better first
|
||||
def codegen(self) -> Callable:
|
||||
def codegen(self) -> ASTRunner:
|
||||
self.process()
|
||||
self.upcast_in_mid_reduce = False
|
||||
if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)
|
||||
if KOPT == -1 or IMAGE == 2: self.hand_coded_optimizations()
|
||||
self.hand_coded_optimizations()
|
||||
|
||||
# add a local buffer for multistage reduce
|
||||
if len(self.group_for_reduce):
|
||||
@@ -257,18 +250,17 @@ class CLASTKernel(ASTKernel):
|
||||
self.loaded_keys : Dict[Tuple[int,int], Token] = {}
|
||||
|
||||
self.prekernel : Set[str] = set()
|
||||
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(isinstance(buf._buf, CLImage) for buf in self.bufs) else []
|
||||
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs) else []
|
||||
|
||||
# output_shape[-1] is get_global_id(0)
|
||||
if CLANG:
|
||||
if len(self.lang.gid) == 0:
|
||||
self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]
|
||||
else:
|
||||
MAX_OUTPUT_SHAPE = 3
|
||||
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
|
||||
if len(self.output_shape) > MAX_OUTPUT_SHAPE:
|
||||
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {self.lang.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(len(self.lang.gid), len(self.output_shape))) if self.output_shape[-1-i] != 1]
|
||||
if len(self.output_shape) > len(self.lang.gid):
|
||||
# sometimes, there's more dimensions. compact all the dimensions into the first one
|
||||
# TODO: these compactions should be searchable
|
||||
final_dimension = len(self.output_shape)-MAX_OUTPUT_SHAPE
|
||||
final_dimension = len(self.output_shape)-len(self.lang.gid)
|
||||
for i in range(final_dimension-1, -1, -1):
|
||||
self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
|
||||
self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
|
||||
@@ -282,7 +274,7 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets()
|
||||
assert self.reduceopop is not None
|
||||
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceopop]};\n" for accumulator in accumulators]
|
||||
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {GPUCodegen.start_for_op[self.reduceopop]};\n" for accumulator in accumulators]
|
||||
self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
|
||||
self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [accumulators[off] for off in acc_offsets], do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
|
||||
|
||||
@@ -292,9 +284,9 @@ class CLASTKernel(ASTKernel):
|
||||
assert lvalid.min == 1, "local buffer must always be valid"
|
||||
self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n")
|
||||
for i,acc in enumerate(accumulators):
|
||||
self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
|
||||
self.kernel.append(self.lang.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
|
||||
self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
|
||||
self.kernel.append(CLProgram.barrier+"\n")
|
||||
self.kernel.append(self.lang.barrier+"\n")
|
||||
|
||||
if self.upcast_in_mid_reduce:
|
||||
assert len(self.group_for_reduce) == 2
|
||||
@@ -311,86 +303,31 @@ class CLASTKernel(ASTKernel):
|
||||
if self.upcast_in_mid_reduce:
|
||||
self.kernel.append(f'float4 ld = vload4(0, &temp{i}[mid*4]);\n')
|
||||
for j in range(4):
|
||||
self.kernel.append(CLASTKernel.code_for_op[self.reduceopop].replace('A', new_accumulators[i*4+j].tok).replace('B', f'ld.{"xyzw"[j]}')+";\n")
|
||||
self.kernel.append(GPUCodegen.code_for_op[self.reduceopop].replace('A', new_accumulators[i*4+j].tok).replace('B', f'ld.{"xyzw"[j]}')+";\n")
|
||||
else:
|
||||
self.kernel.append(CLASTKernel.code_for_op[self.reduceopop].replace('A', new_accumulators[i].tok).replace('B', f'temp{i}[mid]')+";\n")
|
||||
self.kernel.append(GPUCodegen.code_for_op[self.reduceopop].replace('A', new_accumulators[i].tok).replace('B', f'temp{i}[mid]')+";\n")
|
||||
self.kernel.append("}\n")
|
||||
accumulators = new_accumulators
|
||||
|
||||
# late ast
|
||||
self.store(0, self.ast_parse(self.ast, accumulators))
|
||||
if self.group_for_reduce: self.kernel.append("}")
|
||||
if CLANG: self.kernel += ["}"] * len(self.output_shape)
|
||||
if len(self.lang.gid) == 0: self.kernel += ["}"] * len(self.output_shape)
|
||||
self.kernel.append("\n}")
|
||||
|
||||
# kernel function definition
|
||||
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
|
||||
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else (CLProgram.buffer_prefix+self.buftokens[i].decltype() + ("restrict" if CLANG else "")) for i,x in enumerate(self.bufs)]
|
||||
self.kernel = list(self.prekernel) + [f"{CLProgram.kernel_prefix} void {function_name}(",] + \
|
||||
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + CLProgram.extra_args)] + \
|
||||
GPUCodegen.kernel_cnt[function_name] += 1
|
||||
if GPUCodegen.kernel_cnt[function_name]:
|
||||
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
|
||||
|
||||
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
|
||||
self.kernel = list(self.prekernel) + [f"{self.lang.kernel_prefix} void {function_name}(",] + \
|
||||
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + \
|
||||
[") {\n"] + self.kernel
|
||||
|
||||
# compile kernel
|
||||
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
|
||||
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
|
||||
return GPURunner(self.fxn, self.bufs_to_delete, self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None)
|
||||
|
||||
def print(self):
|
||||
super().print()
|
||||
for i in range(len(self.bufs)):
|
||||
print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.sts[i])
|
||||
print(self.fxn.prg)
|
||||
|
||||
class GPUBuffer(ExplicitExecAST):
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
|
||||
super().__init__(shape, hostbuf)
|
||||
self._buf : Optional[Union[CLImage, CLBuffer]] = hostbuf._buf if hostbuf is not None else None
|
||||
self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
|
||||
# early copy in for large buffers
|
||||
if (self._backing is not None and self._backing.shape != (1,)) or force_create:
|
||||
self.cl
|
||||
|
||||
# TODO: refactor this to return self._buf and not import pyopencl
|
||||
@property
|
||||
def cl(self) -> Union[CLBuffer, CLImage]:
|
||||
if self._buf is None:
|
||||
self._buf = CLImage(self._base_shape) if (len(self._base_shape) == 3 and self._base_shape[2] == 4 and IMAGE >= 2) else CLBuffer(4*prod(self._base_shape))
|
||||
assert self._buf is not None
|
||||
if self._backing is not None:
|
||||
assert GlobalCounters.cache is None, f"can't copy in {self._backing.shape} while caching"
|
||||
self._buf.copyin(self._backing)
|
||||
self._backing = None
|
||||
return self._buf._cl
|
||||
|
||||
# TODO: we don't always need a hostbuf
|
||||
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
|
||||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
cl_buf = self.contiguous()
|
||||
cl_buf.cl # force buffer creation, happens if it's a backed buffer that hasn't been created yet
|
||||
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self.movement_op(MovementOps.RESHAPE, tuple(list(self.shape)+[1])), )))
|
||||
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
|
||||
cl_buf._buf.copyout(data)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None):
|
||||
k = CLASTKernel(ast, output_buffer)
|
||||
if KOPT > 0:
|
||||
from extra.kernel_search import apply_optimization
|
||||
apply_optimization(k, ast, max_interventions=KOPT)
|
||||
prg = k.codegen()
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
|
||||
prg(*k.bufs)
|
||||
if PRINT_AST == "1" or (hasattr(k, "fxn") and PRINT_AST == k.fxn.name):
|
||||
print(k.fxn.name)
|
||||
k.print()
|
||||
if TEST_AST:
|
||||
from extra.lib_test_ast import test_ast # type: ignore
|
||||
test_ast(k)
|
||||
return k.ret
|
||||
return ASTRunner(function_name, ' '.join(self.kernel), self.bufs_to_delete,
|
||||
self.output_shape[::-1] if len(self.output_shape) > 0 else [1],
|
||||
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
|
||||
op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
|
||||
@@ -1,19 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import functools
|
||||
from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.ops import LazyOp
|
||||
from tinygrad.ast import ASTKernel
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST
|
||||
from tinygrad.runtime.llvm import LLVM, ir
|
||||
import functools, math
|
||||
from typing import ClassVar, List
|
||||
from llvmlite import ir # type: ignore
|
||||
from tinygrad.codegen.ast import ASTKernel, ASTRunner
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp
|
||||
from tinygrad.helpers import DEBUG, prod
|
||||
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode
|
||||
|
||||
def int_const(x): return ir.Constant(ir.IntType(64), x)
|
||||
|
||||
render_llvm = {
|
||||
Variable: lambda self,ops,ctx: self.expr,
|
||||
NumNode: lambda self,ops,ctx: int_const(self.b),
|
||||
@@ -26,7 +19,7 @@ render_llvm = {
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
|
||||
}
|
||||
|
||||
class LLVMBuffer(ExplicitExecAST):
|
||||
class LLVMCodegen(ASTKernel):
|
||||
op_lookup : ClassVar = {
|
||||
UnaryOps.NOOP: lambda builder,x: x,
|
||||
UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
|
||||
@@ -46,41 +39,10 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
|
||||
}
|
||||
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None, force_create=False):
|
||||
super().__init__(shape, hostbuf)
|
||||
# TODO: force alignment?
|
||||
self._buf = (ctypes.c_float * (prod(self.shape)))() if hostbuf is None else hostbuf._buf
|
||||
#assert ctypes.addressof(self._buf) & 0x1F == 0
|
||||
def codegen(self):
|
||||
self.process()
|
||||
if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)
|
||||
|
||||
def __repr__(self): return f"LLVMBuffer {str(self.st)}"
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x):
|
||||
x = x.astype(np.float32)
|
||||
ret = LLVMBuffer(x.shape)
|
||||
ctypes.memmove(ret._buf, x.ctypes.data, prod(ret.shape)*4)
|
||||
return ret
|
||||
|
||||
def toCPU(x): return np.ctypeslib.as_array(x.contiguous()._buf)[:prod(x.shape)].reshape(x.shape).copy()
|
||||
|
||||
func_cache : Dict[str, Any] = {}
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[LLVMBuffer]=None) -> LLVMBuffer:
|
||||
k = ASTKernel(ast, output_buffer)
|
||||
|
||||
# cached kernel
|
||||
if k.key in LLVMBuffer.func_cache:
|
||||
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
|
||||
return k.ret
|
||||
|
||||
# process if uncached
|
||||
k.process()
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(k.ast)
|
||||
print("old:", [x.shape for x in k.sts])
|
||||
print("old:", [x.views[-1].strides for x in k.sts])
|
||||
|
||||
# this stuff can't be hand coded
|
||||
kernel_output_axis : List[int] = []
|
||||
"""
|
||||
@@ -119,12 +81,12 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
"""
|
||||
|
||||
# the 4x4 need to go all the way at the end, even after reduce
|
||||
output_shape = k.sts[0].shape
|
||||
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
|
||||
output_shape = self.sts[0].shape
|
||||
full_shape_options = [x.shape for x in self.sts if x.shape != output_shape]
|
||||
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
|
||||
|
||||
full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
|
||||
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
|
||||
kernel_output_dim = prod([self.sts[0].shape[a] for a in kernel_output_axis])
|
||||
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
|
||||
|
||||
def get_idxs(builder, idx, buf_index):
|
||||
@@ -143,7 +105,7 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
|
||||
# create llvm function
|
||||
module = ir.Module(name=__file__)
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(k.bufs))), name='exec')
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(self.bufs))), name='exec')
|
||||
|
||||
# force llvmlite to allow us to add function attribute then add the attribute
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
@@ -158,13 +120,13 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
loop_exit = loop_exit[::-1]
|
||||
|
||||
# add the buffer indexing
|
||||
idx_level = [[int_const(st.offset)] for st in k.sts]
|
||||
idx_level = [[int_const(st.offset)] for st in self.sts]
|
||||
for i in range(len(full_shape)):
|
||||
for j in range(len(k.bufs)):
|
||||
for j in range(len(self.bufs)):
|
||||
# stride
|
||||
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
|
||||
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
|
||||
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
|
||||
si_ps = loop_exit[i+1].add(si, int_const(self.sts[j].views[-1].strides[i]))
|
||||
si.add_incoming(si_ps, loop_exit[i+1]._block)
|
||||
idx_level[j].append(si)
|
||||
|
||||
@@ -172,7 +134,7 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
def ast_parse(builder, x, level, reduce_result=None):
|
||||
if not isinstance(x, LazyOp):
|
||||
m = kernel_output_type(ir.Undefined)
|
||||
buf_index = k.bufs.index(x)
|
||||
buf_index = self.bufs.index(x)
|
||||
for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)):
|
||||
# first view is already implictly handled
|
||||
idx, valid = x.st._expr_idx(Variable(idx, 0, prod(x.st.shape)))
|
||||
@@ -195,12 +157,12 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
|
||||
m = kernel_output_type(ir.Undefined)
|
||||
if kernel_output_dim == 1:
|
||||
return LLVMBuffer.op_lookup[x.op](builder, *values)
|
||||
return LLVMCodegen.op_lookup[x.op](builder, *values)
|
||||
else:
|
||||
# TODO: this only has to be done for certain ops
|
||||
for i in range(kernel_output_dim):
|
||||
value = [builder.extract_element(v, int_const(i)) for v in values]
|
||||
element = LLVMBuffer.op_lookup[x.op](builder, *value)
|
||||
element = LLVMCodegen.op_lookup[x.op](builder, *value)
|
||||
m = builder.insert_element(m, element, int_const(i))
|
||||
return m
|
||||
|
||||
@@ -209,9 +171,9 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
|
||||
# do the early ast
|
||||
reduce_result = None
|
||||
if k.reduceop:
|
||||
reduce_input = ast_parse(loop_exit[-1], k.reduceop.src[0], -1)
|
||||
phis = [LLVMBuffer.start_for_op[k.reduceop.op]] # type: ignore
|
||||
if self.reduceop:
|
||||
reduce_input = ast_parse(loop_exit[-1], self.reduceop.src[0], -1)
|
||||
phis = [LLVMCodegen.start_for_op[self.reduceop.op]] # type: ignore
|
||||
if kernel_output_dim > 1:
|
||||
phis = [kernel_output_type(phis * kernel_output_dim)]
|
||||
for i in range(store_loop+1, len(loop_entry)):
|
||||
@@ -219,16 +181,16 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
val.add_incoming(phis[-1], loop_entry[i-1]._block)
|
||||
phis.append(val)
|
||||
|
||||
if k.reduceop.op == ReduceOps.SUM:
|
||||
if self.reduceop.op == ReduceOps.SUM:
|
||||
reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',))
|
||||
elif k.reduceop.op == ReduceOps.MAX:
|
||||
elif self.reduceop.op == ReduceOps.MAX:
|
||||
reduce_result = loop_exit[-1].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',))
|
||||
|
||||
for i,phi in enumerate(phis[1:]):
|
||||
phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block)
|
||||
|
||||
# do the late ast
|
||||
result = ast_parse(loop_exit[store_loop], k.ast, store_loop, reduce_result=reduce_result)
|
||||
result = ast_parse(loop_exit[store_loop], self.ast, store_loop, reduce_result=reduce_result)
|
||||
|
||||
# store result
|
||||
builder = loop_exit[store_loop]
|
||||
@@ -247,5 +209,5 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
|
||||
loop_entry[-1].branch(loop_exit[-1]._block)
|
||||
loop_exit[0].ret_void()
|
||||
LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
|
||||
return k.ret
|
||||
|
||||
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
|
||||
@@ -13,7 +13,7 @@ class TinyJit:
self.input_replace : Dict[DeviceBuffer, Any]= {}

def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
if Device.DEFAULT not in ["GPU", "CLANG"]: return self.fxn(*args, **kwargs) # only jit on the GPU
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_tensors = {k:cast(DeviceBuffer, v.realize().lazydata.realized)._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_tensors) != 0, "no inputs to JIT"

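A tiny sketch of the gating pattern this change introduces: jit only on an allow-list of backends, plain execution elsewhere. JIT_DEVICES and maybe_jit are invented names for illustration only:

JIT_DEVICES = {"GPU", "CLANG"}                  # invented name; the hunk hard-codes the list inline

def maybe_jit(plain_fxn, jitted_fxn, device):
  # use the captured/jitted path only on supported backends, else fall back
  return jitted_fxn if device in JIT_DEVICES else plain_fxn

run = maybe_jit(lambda x: x + 1, lambda x: x + 2, "CPU")
assert run(1) == 2                              # CPU falls back to the plain function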
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import sys, weakref, importlib, inspect
import os, sys, weakref, importlib, inspect
from weakref import WeakValueDictionary
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape import ShapeTracker
@@ -14,21 +14,20 @@ sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZY = getenv("LAZY", 1)

def get_buffer(name, base='tinygrad.llops'):
def get_buffer(name, base='tinygrad.runtime'):
try:
return (name.upper(), [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0])
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0]
except Exception as e: # NOTE: this can't be put on one line due to mypy issue
print(name, "backend not available", e, file=sys.stderr)

class _Device:
def __init__(self) -> None:
self._buffers : Dict[str, Type[DeviceBuffer]] = {x[0]:x[1] for x in [
get_buffer('cpu'), get_buffer('gpu'), get_buffer('llvm'), get_buffer('torch'),
get_buffer('triton', 'accel.triton')] if x is not None}
self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in
[os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None}
self.DEFAULT : str = "CPU"
for name in self._buffers:
if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
self.__setattr__(name, name)
if self._buffers[name] is not None: self.__setattr__(name, name)
Device = _Device()

# TODO: movement ops that only change shape are really nops. treat them as such

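With this hunk, `_Device` no longer hard-codes the backend list; it discovers backends by scanning `tinygrad/runtime/` for `ops_*.py` files. A small self-contained sketch of that discovery pattern (the helper name is hypothetical):

```python
# Hedged sketch of directory-based backend discovery, mirroring the new _Device.__init__:
# every runtime/ops_<name>.py contributes a backend called <NAME>.
import os

def discover_backends(runtime_dir: str) -> list:
  names = []
  for fn in sorted(os.listdir(runtime_dir)):
    if fn.startswith("ops_") and fn.endswith(".py"):
      names.append(os.path.splitext(fn)[0][len("ops_"):].upper())
  return names

# e.g. discover_backends("tinygrad/runtime") might return ['CLANG', 'CPU', 'CUDA', 'GPU', ...]
```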
@@ -1,7 +1,7 @@
from __future__ import annotations
import numpy as np
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar
import functools, operator
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape import ShapeTracker
@@ -32,13 +32,35 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
if x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)

_T = TypeVar("_T")
class RawBuffer:
def __init__(self, size): raise NotImplementedError("must be implemented")
@classmethod
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")

class RawBufferCopyIn(RawBuffer):
def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

@classmethod
def fromCPU(cls, x:np.ndarray):
ret = cls(4*prod(x.shape))
ret.copyin(x)
return ret

class RawBufferCopyInOut(RawBufferCopyIn):
size : int
def copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

def toCPU(self) -> np.ndarray:
x = np.empty((self.size//4), dtype=np.float32)
self.copyout(x)
return x

# a placeholder class to extend by the exec classes
class DeviceBuffer:
class DeviceBuffer(RawBuffer):
_buf: Any # underlying buffer
shape: Tuple[int, ...]
@staticmethod
def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
def toCPU(self:DeviceBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")

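The new `RawBuffer` hierarchy reduces a backend's memory work to one or two primitives: implement `copyin` (and optionally `copyout`), and `fromCPU`/`toCPU` fall out of the base classes. A minimal sketch of the smallest possible subclass, assuming the classes land in `tinygrad.ops` exactly as in this diff (the `RawNumpyBuffer` name is hypothetical):

```python
# Hedged sketch: a host-memory RawBufferCopyInOut backed by a plain numpy array.
# Only copyin/copyout are hand-written; fromCPU/toCPU are inherited from the bases above.
import numpy as np
from tinygrad.ops import RawBufferCopyInOut  # assumes the class added in this diff

class RawNumpyBuffer(RawBufferCopyInOut):
  def __init__(self, size):                          # size is in bytes, matching 4*prod(shape)
    self.size = size
    self._np = np.zeros(size // 4, dtype=np.float32)
  def copyin(self, x: np.ndarray) -> None: self._np[:] = x.reshape(-1)
  def copyout(self, x: np.ndarray) -> None: np.copyto(x, self._np[:x.size].reshape(x.shape))
```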
@@ -53,14 +75,14 @@ shape_fxn_for_op : Dict[Op, Callable] = {
**{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops), op) for op in MovementOps}}

# used in CPUBuffer and TorchBuffer
class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
fxn_for_op : ClassVar = shape_fxn_for_op
# TODO: use generic types here to remove __init__ in specialized classes
def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
def __init__(self, lbuf:Any): self._buf, self.shape = lbuf, tuple(lbuf.shape)
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None):
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None):
if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
srcs = [cls.exec_ast(x) if isinstance(x, LazyOp) else x for x in ast.src]
@@ -68,22 +90,57 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
else: ret = cls(cls.fxn_for_op[ast.op](*([x.buf for x in srcs] + ([ast.arg] if ast.arg else []))))
else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg else []))))
if output_buffer is not None:
assert output_buffer.shape == ret.shape
output_buffer.buf = ret.buf
output_buffer._buf = ret._buf
return output_buffer
else:
return ret
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(map_buffers({x:GenericExecAST(GenericShape(x.shape)) for x in get_buffers(ast)}, ast)).buf
def get_lazyop_info(ast:LazyOp): return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape)) for x in get_buffers(ast)}, ast))._buf

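For interpreted backends (the renamed `InterpretedBuffer`), `exec_ast` simply walks the LazyOp tree and applies `fxn_for_op` at each node. A hedged illustration of that path, assuming the module layout in this diff (`CPUBuffer` as defined further down):

```python
# Hedged illustration: build a tiny LazyOp by hand and let CPUBuffer.exec_ast interpret it.
import numpy as np
from tinygrad.ops import LazyOp, BinaryOps
from tinygrad.runtime.ops_cpu import CPUBuffer

a = CPUBuffer.fromCPU(np.arange(6, dtype=np.float32).reshape(2, 3))
b = CPUBuffer.fromCPU(np.ones((2, 3), dtype=np.float32))
out = CPUBuffer.exec_ast(LazyOp(op=BinaryOps.ADD, src=(a, b)))  # srcs are realized, then ADD is applied
print(out.toCPU())  # expected: [[1. 2. 3.] [4. 5. 6.]]
```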
# assumes you are using ShapeTracker
# used in GPUBuffer and LLVMBuffer
class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[CompiledBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape = self.st.shape
self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
self._buf = hostbuf._buf if hostbuf is not None else None
self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
if (self._backing is not None and self._backing.shape != (1,)) or force_create: self.raw()

# TODO: not GPUBuffer, get name of class
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")

raw_buffer_type : Type[RawBuffer]
@classmethod
def create_raw_buffer(cls, shape, backing) -> RawBuffer:
assert backing is None or prod(shape) == prod(backing.shape), "backing has the wrong shape"
assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
def raw(self) -> RawBuffer:
if self._buf is None: self._buf = self.create_raw_buffer(self._base_shape, self._backing)
self._backing = None
return self._buf

@classmethod
def fromCPU(cls, x:np.ndarray) -> CompiledBuffer: return cls(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
def toCPU(self) -> np.ndarray:
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
return self.contiguous().raw().toCPU().reshape(self.shape)

codegen_type : Any
runtime_type : Type
method_cache : Dict[str, Callable] = {}
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
k = cls.codegen_type(ast, output_buffer)
if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.runtime_type)
prg = cls.method_cache[k.key]
if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
prg(*k.bufs)
return k.ret

# universal for shape tracked
def contiguous(self): return self if self.st.contiguous and prod(self._base_shape) == prod(self.shape) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
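`CompiledBuffer.exec_ast` above is the heart of the "method cache" mentioned in the commit message: codegen the AST into a kernel, build it once per key with the backend's runtime, then launch it with the kernel's buffers. A hedged sketch of just the caching step, with stand-in names for `k.codegen()` and `.build(runtime_type)`:

```python
# Hedged sketch of the method-cache pattern in CompiledBuffer.exec_ast:
# identical ASTs hash to the same key, so codegen + compilation happen once per kernel.
from typing import Callable, Dict

method_cache: Dict[str, Callable] = {}

def compile_or_lookup(key: str, codegen: Callable[[], str], build: Callable[[str], Callable]) -> Callable:
  # codegen() produces kernel source, build() turns it into a callable program
  if key not in method_cache:
    method_cache[key] = build(codegen())
  return method_cache[key]
```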
@@ -92,12 +149,12 @@ class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
class GlobalCounters:
global_ops : ClassVar[int] = 0
global_mem : ClassVar[int] = 0
time_sum : ClassVar[int] = 0
time_sum_s : ClassVar[float] = 0.0
kernel_count : ClassVar[int] = 0
mem_used : ClassVar[int] = 0 # NOTE: this is not reset
cache : ClassVar[Optional[list]] = None
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
@staticmethod
def log_kernel(op_estimate:int, mem_estimate:int):
GlobalCounters.kernel_count += 1

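`GlobalCounters.cache` is how kernel capture works: when it is set to a list, `CompiledBuffer.exec_ast` appends `(program, buffers)` pairs instead of only running them, which is what the EfficientNet export relies on. A hedged usage sketch:

```python
# Hedged sketch of capturing kernels via GlobalCounters.cache:
# set the cache, run the realized model once, then read the captured kernels back.
from tinygrad.ops import GlobalCounters

GlobalCounters.cache = []            # start capturing (prg, bufs) pairs
# ... run the model / realize the output tensor here ...
captured = GlobalCounters.cache
GlobalCounters.cache = None          # stop capturing
for prg, bufs in captured:
  print(type(prg).__name__, len(bufs))   # each entry: a compiled program plus its argument buffers
```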
@@ -1,38 +0,0 @@
|
||||
import ctypes
|
||||
import os
|
||||
import numpy as np
|
||||
import hashlib
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from typing import List, Final, Dict
|
||||
from tinygrad.helpers import DEBUG
|
||||
import platform
|
||||
OSX = platform.system() == "Darwin"
|
||||
|
||||
class CLBuffer:
|
||||
def __init__(self, size): self._cl = (ctypes.c_float * (size))()
|
||||
def copyin(self, b:np.ndarray): ctypes.memmove(self._cl, b.ctypes.data, b.size*4)
|
||||
def copyout(self, a:np.ndarray):
|
||||
np.copyto(a, np.ctypeslib.as_array(self._cl)[:a.size].reshape(a.shape))
|
||||
|
||||
class CLProgram:
|
||||
kernel_prefix, buffer_prefix, smem_prefix, barrier = "", "", "", ""
|
||||
gid = [f"gid[{i}]" for i in range(3)]
|
||||
lid = [f"lid[{i}]" for i in range(3)]
|
||||
extra_args : List[str] = []
|
||||
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
|
||||
# TODO: remove name, factor out op_estimate and mem_estimate
|
||||
def __init__(self, name:str, prg:str, rename=True, op_estimate=0, mem_estimate=0):
|
||||
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
|
||||
CLProgram.kernel_cnt[name] += 1
|
||||
self.prg = prg.replace(f"{name}(", f"{self.name}(")
|
||||
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n" + prg
|
||||
if DEBUG >= 4: print(prg) # TODO: outside runtime!
|
||||
# TODO: is there a way to not write this to disk?
|
||||
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if OSX else 'so'}"
|
||||
if not os.path.exists(fn):
|
||||
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
|
||||
os.rename(fn+".tmp", fn)
|
||||
self.lib = ctypes.CDLL(fn)
|
||||
self.fxn = self.lib[name]
|
||||
def __call__(self, *args): self.fxn(*args[2:])
|
||||
@@ -1,37 +0,0 @@
|
||||
from typing import Optional, List
|
||||
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
|
||||
import pycuda.driver as cuda # type: ignore
|
||||
from pycuda.compiler import compile # type: ignore
|
||||
import numpy as np
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
||||
class CLBuffer:
|
||||
def __init__(self, size): self._cl = cuda.mem_alloc(size)
|
||||
def copyin(self, b:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, b, stream)
|
||||
def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self._cl)
|
||||
|
||||
class CLProgram:
|
||||
kernel_prefix = "__global__"
|
||||
buffer_prefix = ""
|
||||
smem_prefix = "__shared__ "
|
||||
barrier = "__syncthreads();"
|
||||
float4 = "make_float4"
|
||||
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
|
||||
extra_args : List[str] = []
|
||||
def __init__(self, name:str, prg:str, binary=False, shared=0, op_estimate:int=0, mem_estimate:int=0):
|
||||
self.name, self.op_estimate, self.mem_estimate, self.shared = name, op_estimate, mem_estimate, shared
|
||||
if DEBUG >= 4 and not binary: print("CUDA compile", prg)
|
||||
if not binary: prg = compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
|
||||
if DEBUG >= 5: print(prg)
|
||||
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
|
||||
|
||||
def __call__(self, global_size, local_size, *args):
|
||||
local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
|
||||
global_size = global_size + [1] * (3 - len(global_size))
|
||||
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
|
||||
global_size = [x//y for x,y in zip(global_size, local_size)]
|
||||
if DEBUG >= 2: print("CUDA launch", global_size, local_size)
|
||||
self.prg(*args, block=tuple(local_size), grid=tuple(global_size), shared=self.shared)
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
@@ -1,98 +0,0 @@
|
||||
# pip3 install pyobjc-framework-Metal pyobjc-framework-libdispatch
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.helpers import prod, getenv, DEBUG
|
||||
import subprocess, pathlib
|
||||
|
||||
METAL_XCODE = getenv("METAL_XCODE")
|
||||
|
||||
device = Metal.MTLCreateSystemDefaultDevice()
|
||||
mtl_queue = device.newCommandQueue()
|
||||
mtl_buffers_in_flight : List[Any] = []
|
||||
|
||||
def sync():
|
||||
global mtl_buffers_in_flight
|
||||
for cbuf in mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
mtl_buffers_in_flight = []
|
||||
|
||||
class CLBuffer:
|
||||
def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
|
||||
def __del__(self): self._cl.release()
|
||||
|
||||
def copyin(self, b:np.ndarray):
|
||||
np.copyto(np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32), b.reshape(-1).data)
|
||||
|
||||
def toCPU(self):
|
||||
sync()
|
||||
return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
|
||||
|
||||
# TODO: remove copyout everywhere
|
||||
def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))
|
||||
|
||||
class CLProgram:
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel"
|
||||
buffer_prefix = "device "
|
||||
smem_prefix = "threadgroup "
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
|
||||
float4 = "float4"
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)]
|
||||
lid = [f"lid.{chr(120+i)}" for i in range(3)]
|
||||
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
|
||||
self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
|
||||
if DEBUG >= 4: print("Metal compile", prg)
|
||||
if DEBUG >= 6: # dump llvm
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
|
||||
print(dis.decode('utf-8'))
|
||||
if METAL_XCODE:
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library, err = device.newLibraryWithData_error_(data, None)
|
||||
else:
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
self.library, err = device.newLibraryWithSource_options_error_(prg, options, None)
|
||||
assert err is None, str(err)
|
||||
self.fxn = self.library.newFunctionWithName_(name) #self.library.functionNames()[0]
|
||||
# hacks to disassemble shader
|
||||
if DEBUG >= 5:
|
||||
arc, err = device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
|
||||
assert err is None, str(err)
|
||||
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
|
||||
desc.setComputeFunction_(self.fxn)
|
||||
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
|
||||
assert err is None, str(err)
|
||||
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
|
||||
assert err is None, str(err)
|
||||
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
|
||||
import os
|
||||
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
|
||||
self.pipeline_state, err = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert err is None, str(err)
|
||||
|
||||
def __call__(self, global_size, local_size, *args):
|
||||
global_size += [1] * (3-len(global_size))
|
||||
if local_size is None: local_size = [32]
|
||||
local_size += [1] * (3-len(local_size))
|
||||
|
||||
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
|
||||
command_buffer = mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
for i,a in enumerate(args):
|
||||
encoder.setBuffer_offset_atIndex_(a, 0, i)
|
||||
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
if DEBUG >= 2:
|
||||
command_buffer.waitUntilCompleted()
|
||||
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
print(f"METAL et {et*1e6:8.2f} us {self.name:28s} launch {str(global_size):18s} {local_size}")
|
||||
GlobalCounters.time_sum += et
|
||||
else:
|
||||
mtl_buffers_in_flight.append(command_buffer)
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
return command_buffer
|
||||
@@ -1,101 +0,0 @@
|
||||
import functools, platform
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Dict, Optional, Tuple, List, ClassVar, Final
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
|
||||
OSX = platform.system() == "Darwin"
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
|
||||
CLCACHE = getenv("CLCACHE", 1)
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
|
||||
class CL:
|
||||
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
|
||||
cl_ctx : ClassVar[Optional[cl.Context]] = None
|
||||
cl_queue : ClassVar[Optional[cl.CommandQueue]] = None
|
||||
def __init__(self) -> None:
|
||||
if CL.cl_queue is not None: return # already initted
|
||||
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
|
||||
if len(devices) == 0: # settle for CPU
|
||||
devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
|
||||
CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
|
||||
if len(devices) > 1 or DEBUG >= 1: print(f"using {CL.cl_ctx.devices}")
|
||||
CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
|
||||
|
||||
@staticmethod
|
||||
def enqueue_copy(a, b, is_blocking=False):
|
||||
if DEBUG >= 1: print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}")
|
||||
cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)
|
||||
|
||||
class CLBuffer:
|
||||
def __init__(self, size):
|
||||
if DEBUG >= 4: print(f"allocate GPU Buffer {size}")
|
||||
if len(CL.BUFFER_CACHE[size]) > 0:
|
||||
self._cl = CL.BUFFER_CACHE[size].pop()
|
||||
else:
|
||||
# TODO: on GPU OOM, clear the cache
|
||||
self._cl = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size)
|
||||
GlobalCounters.mem_used += self._cl.size
|
||||
|
||||
def __del__(self):
|
||||
if CLCACHE: CL.BUFFER_CACHE[self._cl.size].append(self._cl)
|
||||
else: GlobalCounters.mem_used -= self._cl.size
|
||||
|
||||
def copyin(self, b:np.ndarray): CL.enqueue_copy(self._cl, b, False)
|
||||
def copyout(self, a:np.ndarray): CL.enqueue_copy(a, self._cl, True)
|
||||
|
||||
class CLImage:
|
||||
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
|
||||
|
||||
def __init__(self, shape):
|
||||
self._cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
|
||||
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
|
||||
|
||||
def __del__(self):
|
||||
GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
|
||||
|
||||
def copyin(self, b:np.ndarray): raise NotImplementedError("no copyin for CLImage")
|
||||
def copyout(self, a:np.ndarray): raise NotImplementedError("no copyout for CLImage")
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
class CLProgram:
|
||||
kernel_prefix = "__kernel"
|
||||
buffer_prefix = "__global "
|
||||
smem_prefix = "__local "
|
||||
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
|
||||
float4 = "(float4)"
|
||||
gid = [f'get_global_id({i})' for i in range(3)]
|
||||
lid = [f'get_local_id({i})' for i in range(3)]
|
||||
extra_args : List[str] = []
|
||||
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
|
||||
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
|
||||
self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate
|
||||
self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg) # type: ignore
|
||||
try:
|
||||
self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
|
||||
except cl.RuntimeError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", self.prg)
|
||||
raise e
|
||||
if self.argdtypes is not None:
|
||||
self.clprg.set_scalar_arg_dtypes(self.argdtypes)
|
||||
CLProgram.kernel_cnt[name] += 1
|
||||
def __call__(self, *args) -> cl.Event:
|
||||
if DEBUG >= 4: print(args[0], args[1], self.prg)
|
||||
# print the PTX for NVIDIA. TODO: probably broken for everything else
|
||||
if DEBUG >= 5 and not OSX: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8'))
|
||||
e = self.clprg(CL().cl_queue, *args)
|
||||
if DEBUG >= 2:
|
||||
assert CL.cl_queue is not None
|
||||
CL.cl_queue.finish()
|
||||
# NOTE: Profiling is not in ns in OS X, we multiply by a computed ratio
|
||||
et = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
|
||||
GlobalCounters.time_sum += et
|
||||
if DEBUG >= 1:
|
||||
print(f"**CL** {GlobalCounters.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if DEBUG <= 1 else f"tm {et/1e3:9.2f}us/{GlobalCounters.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)"))
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
return e
|
||||
37
tinygrad/runtime/ops_clang.py
Normal file
@@ -0,0 +1,37 @@
import ctypes
import os, time
import numpy as np
import hashlib
import subprocess
from collections import defaultdict
from typing import Final, Dict
from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
from tinygrad.codegen.gpu import GPUCodegen
import platform
OSX = platform.system() == "Darwin"

class RawMallocBuffer(RawBufferCopyIn):
def __init__(self, size): self._buf = (ctypes.c_float * (size//4))()
def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
def toCPU(self): return np.ctypeslib.as_array(self._buf)

class ClangProgram:
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
def __init__(self, name:str, prg:str):
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n" + prg
# TODO: is there a way to not write this to disk?
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if OSX else 'so'}"
if not os.path.exists(fn):
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
os.rename(fn+".tmp", fn)
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
def __call__(self, *args, wait=False):
if wait: st = time.monotonic()
self.fxn(*[x._buf for x in args[2:]])
if wait: return time.monotonic()-st

class ClangBuffer(CompiledBuffer):
raw_buffer_type = RawMallocBuffer
codegen_type = GPUCodegen # clang is the default
runtime_type = ClangProgram
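The new clang backend compiles generated C into a shared library and calls it through ctypes. A self-contained, hedged sketch of that compile-and-dlopen pattern outside of tinygrad (source string, file name, and function name below are made up for illustration):

```python
# Hedged, standalone sketch of the pattern ClangProgram uses: pipe C source to clang,
# build a shared object, dlopen it with ctypes, and call the symbol with ctypes arrays.
import ctypes, hashlib, os, subprocess, tempfile

src = "void add1(float* out, const float* inp) { for (int i = 0; i < 4; i++) out[i] = inp[i] + 1.0f; }"
fn = os.path.join(tempfile.gettempdir(), f"sketch_{hashlib.md5(src.encode()).hexdigest()}.so")
if not os.path.exists(fn):
  subprocess.check_output(['clang', '-shared', '-O2', '-fPIC', '-x', 'c', '-', '-o', fn], input=src.encode())
lib = ctypes.CDLL(fn)

out = (ctypes.c_float * 4)()
inp = (ctypes.c_float * 4)(0, 1, 2, 3)
lib.add1(out, inp)
print(list(out))  # expected: [1.0, 2.0, 3.0, 4.0]
```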
@@ -1,7 +1,7 @@
import numpy as np
import operator
from typing import ClassVar, Callable, Dict
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, GenericExecAST, Op
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, InterpretedBuffer, Op
from tinygrad.helpers import shape_to_axis

base_fxn_for_op : Dict[Op, Callable] = {
@@ -30,9 +30,9 @@ numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to)
}}

class CPUBuffer(GenericExecAST):
class CPUBuffer(InterpretedBuffer):
fxn_for_op : ClassVar = numpy_fxn_for_op

@staticmethod
def fromCPU(x): return CPUBuffer(x)
def toCPU(x): return x.buf
def toCPU(x): return x._buf
38
tinygrad/runtime/ops_cuda.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
|
||||
import pycuda.driver as cuda # type: ignore
|
||||
from pycuda.compiler import compile # type: ignore
|
||||
import numpy as np
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
|
||||
class RawCUDABuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size)
|
||||
def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
|
||||
def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
|
||||
|
||||
class CUDAProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
if not binary: prg = compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
|
||||
if DEBUG >= 5: print(prg)
|
||||
# TODO: name is wrong
|
||||
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False):
|
||||
local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
|
||||
global_size = global_size + [1] * (3 - len(global_size))
|
||||
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
|
||||
global_size = [x//y for x,y in zip(global_size, local_size)]
|
||||
self.prg(*args, block=tuple(local_size), grid=tuple(global_size))
|
||||
|
||||
class CUDACodegen(GPUCodegen):
|
||||
lang = GPULanguage(
|
||||
kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
|
||||
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
|
||||
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
|
||||
|
||||
class CUDABuffer(CompiledBuffer):
|
||||
raw_buffer_type = RawCUDABuffer
|
||||
codegen_type = CUDACodegen
|
||||
runtime_type = CUDAProgram
|
||||
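`CUDAProgram.__call__` above pads global/local sizes to three dimensions and divides the global size by the local size, because pycuda's grid is expressed in blocks rather than threads. A small hedged helper that mirrors that arithmetic in isolation:

```python
# Hedged sketch of the grid/block normalization in CUDAProgram.__call__.
def to_grid_block(global_size, local_size):
  local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else [1, 1, 1]
  global_size = global_size + [1] * (3 - len(global_size))
  assert all(g % l == 0 for g, l in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
  return tuple(g // l for g, l in zip(global_size, local_size)), tuple(local_size)

# to_grid_block([1024, 4], [32, 4]) -> ((32, 1, 1), (32, 4, 1))
```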
92
tinygrad/runtime/ops_gpu.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
import platform, functools
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Dict, Optional, List, ClassVar, Final
|
||||
from collections import defaultdict
|
||||
from tinygrad.helpers import IMAGE, DEBUG, getenv
|
||||
from tinygrad.ops import CompiledBuffer, GlobalCounters, RawBufferCopyInOut, RawBuffer
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
|
||||
OSX = platform.system() == "Darwin"
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
CLCACHE = getenv("CLCACHE", 1)
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
|
||||
class _CL:
|
||||
@functools.cached_property
|
||||
def cl_ctx(self) -> cl.Context:
|
||||
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
|
||||
if len(devices) == 0: devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) # settle for CPU
|
||||
if len(devices) > 1 or DEBUG >= 1: print(f"using {devices[getenv('CL_DEVICE', 0)]}")
|
||||
return cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
|
||||
|
||||
@functools.cached_property
|
||||
def cl_queue(self) -> cl.CommandQueue:
|
||||
return cl.CommandQueue(CL.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
|
||||
CL = _CL()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
# TODO: this can be in RawBuffer generically
|
||||
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
if len(CLBuffer.BUFFER_CACHE[size]) > 0:
|
||||
self._cl = CLBuffer.BUFFER_CACHE[size].pop()
|
||||
else:
|
||||
# TODO: on GPU OOM, clear the cache
|
||||
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size)
|
||||
GlobalCounters.mem_used += self._cl.size
|
||||
|
||||
def __del__(self):
|
||||
if CLCACHE: CLBuffer.BUFFER_CACHE[self._cl.size].append(self._cl)
|
||||
else: GlobalCounters.mem_used -= self._cl.size
|
||||
|
||||
def copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
|
||||
def copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
|
||||
|
||||
class CLImage(RawBuffer):
|
||||
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
|
||||
IMAGE : Final = True
|
||||
|
||||
def __init__(self, shape):
|
||||
self._cl = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
|
||||
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
|
||||
|
||||
def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
|
||||
|
||||
class CLProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False, argdtypes=None):
|
||||
self.name, self.argdtypes, self.clprogram = name, argdtypes, cl.Program(CL.cl_ctx, CL.cl_ctx.devices, [prg]) if binary else cl.Program(CL.cl_ctx, prg) # type: ignore
|
||||
try:
|
||||
self._clprg = self.clprogram.build()
|
||||
except cl.RuntimeError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
self.clprg = self._clprg.__getattr__(name)
|
||||
if DEBUG >= 5 and not OSX: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8')) # print the PTX for NVIDIA. TODO: probably broken for everything else
|
||||
if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._cl if isinstance(x, (CLBuffer, CLImage)) else x for x in bufs])
|
||||
if wait:
|
||||
CL.cl_queue.finish()
|
||||
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
|
||||
return None
|
||||
|
||||
class CLCodegen(GPUCodegen):
|
||||
lang = GPULanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)])
|
||||
|
||||
class GPUBuffer(CompiledBuffer):
|
||||
raw_buffer_type = CLBuffer
|
||||
# override this method for image
|
||||
@classmethod
|
||||
def create_raw_buffer(cls, shape, backing) -> RawBuffer:
|
||||
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and not backing: return CLImage(shape)
|
||||
else: return super().create_raw_buffer(shape, backing)
|
||||
codegen_type = CLCodegen
|
||||
runtime_type = CLProgram
|
||||
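`CLBuffer` above keeps a size-keyed free list (`BUFFER_CACHE`) so that `__del__` recycles device allocations instead of releasing them. A hedged, backend-agnostic sketch of that pooling idea (class and allocator below are illustrative stand-ins):

```python
# Hedged sketch of the size-keyed buffer pool CLBuffer.BUFFER_CACHE implements:
# allocations of the same byte size are recycled rather than freed and re-made.
from collections import defaultdict
from typing import Dict, List

class BufferPoolSketch:
  free_list: Dict[int, List[object]] = defaultdict(list)

  def __init__(self, size: int, alloc=lambda size: bytearray(size)):
    self.size = size
    # reuse a cached allocation of this exact size if one exists, else allocate fresh
    self._buf = BufferPoolSketch.free_list[size].pop() if BufferPoolSketch.free_list[size] else alloc(size)

  def __del__(self):
    BufferPoolSketch.free_list[self.size].append(self._buf)  # return to the pool instead of freeing
```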
@@ -1,13 +1,12 @@
|
||||
import time, hashlib, ctypes
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import CompiledBuffer
|
||||
from tinygrad.runtime.ops_clang import RawMallocBuffer
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad.ops import GlobalCounters
|
||||
import hashlib
|
||||
import time
|
||||
import ctypes
|
||||
from ctypes import CFUNCTYPE
|
||||
from tinygrad.codegen.llvm import LLVMCodegen
|
||||
|
||||
import llvmlite.binding as llvm # type: ignore
|
||||
from llvmlite import ir # type: ignore
|
||||
|
||||
class LLVM:
|
||||
target_machine : ClassVar[llvm.targets.TargetMachine] = None
|
||||
@@ -15,8 +14,7 @@ class LLVM:
|
||||
optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
|
||||
|
||||
def __init__(self):
|
||||
if LLVM.engine is not None:
|
||||
return
|
||||
if LLVM.engine is not None: return
|
||||
llvm.initialize()
|
||||
llvm.initialize_native_target()
|
||||
llvm.initialize_native_asmprinter()
|
||||
@@ -46,42 +44,26 @@ class LLVM:
|
||||
backing_mod.triple = llvm.get_process_triple()
|
||||
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
|
||||
|
||||
# TODO: LLVMProgram
|
||||
def exec(self, module:ir.Module, bufs, op_estimate=0, mem_estimate=0):
|
||||
module.triple = llvm.get_process_triple()
|
||||
module.data_layout = self.engine.target_data
|
||||
llvm_ir = str(module)
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(llvm_ir)
|
||||
|
||||
mod = llvm.parse_assembly(llvm_ir)
|
||||
mod.verify()
|
||||
LLVM.optimizer.run(mod)
|
||||
if DEBUG >= 4:
|
||||
print("Optimized IR:")
|
||||
print(str(mod))
|
||||
mod.name = hashlib.sha1(llvm_ir.encode('utf-8')).hexdigest()
|
||||
if DEBUG >= 3:
|
||||
print(LLVM.target_machine.emit_assembly(mod))
|
||||
LLVM.engine.add_module(mod)
|
||||
class LLVMProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
self.mod = llvm.parse_assembly(prg)
|
||||
self.mod.verify()
|
||||
LLVM().optimizer.run(self.mod)
|
||||
self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
|
||||
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
|
||||
LLVM.engine.add_module(self.mod)
|
||||
LLVM.engine.finalize_object()
|
||||
self.fxn = LLVM.engine.get_function_address(name)
|
||||
|
||||
# call function (NOTE: if the types don't match, there's likely something wrong with the cache)
|
||||
#cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
|
||||
def __del__(self): LLVM.engine.remove_module(self.mod)
|
||||
|
||||
# why is this needed without the types. fixed tests below
|
||||
# LLVM=1 OPT=2 python3 test/test_ops.py TestOps.test_cat TestOps.test_multicat
|
||||
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for x in bufs])(LLVM.engine.get_function_address('exec'))
|
||||
|
||||
st = time.monotonic()
|
||||
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
|
||||
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for _ in bufs])(self.fxn)
|
||||
if wait: st = time.monotonic()
|
||||
cfunc(*[x._buf for x in bufs])
|
||||
et = time.monotonic() - st
|
||||
if DEBUG >= 1:
|
||||
print(f"**LLVM** time {et*1000:7.2f} ms OPs {op_estimate/1e6:7.2f}M -- {(op_estimate/1e9)/et:5.2f} GFLOPS -- {mem_estimate:10d} reads -- {(mem_estimate*4/1e9)/et:5.2f} GB/s")
|
||||
GlobalCounters.global_ops += op_estimate
|
||||
GlobalCounters.global_mem += mem_estimate
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
# we are done
|
||||
LLVM.engine.remove_module(mod)
|
||||
return cfunc
|
||||
class LLVMBuffer(CompiledBuffer):
|
||||
raw_buffer_type = RawMallocBuffer
|
||||
codegen_type = LLVMCodegen
|
||||
runtime_type = LLVMProgram
|
||||
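`LLVMProgram.__call__` above turns a JITed function address into a Python callable with `ctypes.CFUNCTYPE`, one `float*` argument per buffer. A hedged sketch of that last step in isolation (the helper name is hypothetical):

```python
# Hedged sketch of the address-to-callable step performed in LLVMProgram.__call__.
import ctypes

def call_jitted(fxn_address: int, bufs):
  # one float* per buffer, int return, matching the prototype used in the diff
  proto = ctypes.CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for _ in bufs])
  cfunc = proto(fxn_address)
  return cfunc(*[ctypes.cast(b, ctypes.POINTER(ctypes.c_float)) for b in bufs])
```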
92
tinygrad/runtime/ops_metal.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# pip3 install pyobjc-framework-Metal pyobjc-framework-libdispatch
|
||||
import os, subprocess, pathlib, functools
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
from tinygrad.helpers import prod, getenv, DEBUG
|
||||
from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
|
||||
|
||||
METAL_XCODE = getenv("METAL_XCODE")
|
||||
|
||||
class _METAL:
|
||||
mtl_buffers_in_flight : List[Any] = []
|
||||
@functools.cached_property
|
||||
def device(self):
|
||||
return Metal.MTLCreateSystemDefaultDevice()
|
||||
@functools.cached_property
|
||||
def mtl_queue(self):
|
||||
return METAL.device.newCommandQueue()
|
||||
METAL = _METAL()
|
||||
|
||||
class RawMetalBuffer(RawBufferCopyIn):
|
||||
def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
|
||||
def __del__(self): self._cl.release()
|
||||
def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
|
||||
def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data)
|
||||
def toCPU(self) -> np.ndarray:
|
||||
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
METAL.mtl_buffers_in_flight = []
|
||||
return self._as_np() # no copy!
|
||||
|
||||
class MetalProgram:
|
||||
def __init__(self, name:str, prg:str):
|
||||
if DEBUG >= 6: # dump llvm
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
|
||||
print(dis.decode('utf-8'))
|
||||
if METAL_XCODE:
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library, err = METAL.device.newLibraryWithData_error_(data, None)
|
||||
else:
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
self.library, err = METAL.device.newLibraryWithSource_options_error_(prg, options, None)
|
||||
assert err is None, str(err)
|
||||
self.fxn = self.library.newFunctionWithName_(name)
|
||||
# hacks to disassemble shader
|
||||
if DEBUG >= 5:
|
||||
arc, err = METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
|
||||
assert err is None, str(err)
|
||||
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
|
||||
desc.setComputeFunction_(self.fxn)
|
||||
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
|
||||
assert err is None, str(err)
|
||||
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
|
||||
assert err is None, str(err)
|
||||
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
|
||||
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
|
||||
self.pipeline_state, err = METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert err is None, str(err)
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False):
|
||||
global_size += [1] * (3-len(global_size))
|
||||
if local_size is None: local_size = [32]
|
||||
local_size += [1] * (3-len(local_size))
|
||||
|
||||
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
|
||||
command_buffer = METAL.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._cl, 0, i)
|
||||
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
else:
|
||||
METAL.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
class MetalCodegen(GPUCodegen):
|
||||
lang = GPULanguage(
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
|
||||
|
||||
class MetalBuffer(CompiledBuffer):
|
||||
raw_buffer_type = RawMetalBuffer
|
||||
codegen_type = MetalCodegen
|
||||
runtime_type = MetalProgram
|
||||
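The Metal runtime above commits command buffers without waiting and only blocks when results are read back, draining `mtl_buffers_in_flight` in `toCPU`. A hedged, generic sketch of that deferred-sync pattern (class name and `command_buffer` type are illustrative):

```python
# Hedged sketch of the deferred-sync pattern used by RawMetalBuffer/MetalProgram:
# launches only enqueue work; the wait happens lazily on the first host read.
from typing import Any, List

class InFlightTracker:
  pending: List[Any] = []              # analogous to METAL.mtl_buffers_in_flight

  @classmethod
  def submit(cls, command_buffer: Any):
    command_buffer.commit()            # enqueue, do not block
    cls.pending.append(command_buffer)

  @classmethod
  def sync(cls):                       # called on the first host read, like toCPU()
    for cb in cls.pending: cb.waitUntilCompleted()
    cls.pending.clear()
```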
@@ -1,8 +1,8 @@
import torch
from typing import ClassVar, Final, Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, GenericExecAST, Op
from typing import ClassVar, Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, InterpretedBuffer, Op
from tinygrad.helpers import getenv
from tinygrad.llops.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc

torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
@@ -12,10 +12,9 @@ torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
}}

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(GenericExecAST):
class TorchBuffer(InterpretedBuffer):
fxn_for_op : ClassVar = torch_fxn_for_op
SUPPORTS_SIMPLE_PADDING : Final = True

@staticmethod
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
def toCPU(x): return x.buf.cpu().numpy()
def toCPU(x): return x._buf.cpu().numpy()
Block a user