Refactor ASTs (#622)

* ugh worst branch name

* compiler refactor continues

* scc -> cloc

* buf -> _buf

* finish _buf, and program -> runtime

* gpu is still working, clang isn't

* clang in new style

* ops_metal

* something broke it

* improve metal

* clean up tons of cl crap

* hack fix sync

* cleaner gpu

* gpu metal clang

* cleanups

* minor refactor

* GPUCodegen

* fix up LLVM

* blind CUDA refactor

* codegen / runtime

* keep ops naming

* linter passes

* woah, llvm was allocing 4x what it needed to

* bugfixes

* fix openpilot compiler

* fix compile_efficientnet

* method cache should fix tests

* deal with duped functions
Author: George Hotz
Date: 2023-03-01 18:57:29 -08:00 (committed by GitHub)
Commit: bfcec234a2 (parent 5e41d5857c)
34 changed files with 557 additions and 627 deletions


@@ -72,7 +72,7 @@ jobs:
- name: Install Dependencies
run: pip install -e .
- name: Compile EfficientNet to C
run: CLANG=1 GPU=1 python3 examples/compile_efficientnet.py > recognize.c
run: CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
- name: Compile C to native
run: clang -O2 recognize.c -lm -o recognize
- name: Test EfficientNet


@@ -14,7 +14,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExe
from tinygrad.shape import ShapeTracker
from tinygrad.helpers import prod, DEBUG
from tinygrad.runtime.cuda import CLBuffer
from tinygrad.ast import ASTKernel
from tinygrad.compiler.ast import ASTKernel
stream = cuda.Stream()


@@ -4,31 +4,29 @@ from extra.utils import fetch
import ast
def compile_net(run, special_names):
# c header
cprog = ["#include <stdio.h>", "#include <math.h>", "#define max(x,y) ((x>y)?x:y)"]
# functions that run the net
functions = {}
bufs = {}
bufnum = 0
statements = []
bufs_to_save = {}
for fxn,args in run.jit_cache:
cprog.append(fxn.clprg.prg)
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):
if i in fxn.bufs_to_delete: continue
key = id(arg.cl)
key = id(arg.raw())
if key not in bufs:
if key in special_names:
bufs[key] = (special_names[key], len(arg.cl)//4)
bufs[key] = (special_names[key], len(arg.raw()._buf))
else:
bufs[key] = (f"buf_{bufnum}", len(arg.cl)//4)
bufs[key] = (f"buf_{bufnum}", len(arg.raw()._buf))
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg.cl # if first usage of a buffer is not an output, and it's not a special name
if i > 0: bufs_to_save[bufs[key][0]] = arg.raw() # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
statements.append(f"{fxn.clprg.name}({', '.join(cargs)});")
statements.append(f"{fxn.name}({', '.join(cargs)});")
return cprog, statements, bufs, bufs_to_save
return functions, statements, bufs, bufs_to_save
if __name__ == "__main__":
model = EfficientNet(0)
@@ -44,21 +42,17 @@ if __name__ == "__main__":
the_output = run(the_input)
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
special_names = {id(the_input.lazydata.realized.raw()): "input", id(the_output.lazydata.realized.raw()): "outputs"}
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
# buffers (empty)
cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] not in bufs_to_save]
# c header
cprog = ["#include <stdio.h>", "#include <math.h>", "#define max(x,y) ((x>y)?x:y)"]
# buffers (weights)
# save the weights
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(memoryview(cl)[0:len(cl)//4])])
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog.append(f"float *{name} = (float *){name}_data;")
# the net
cprog += ["void net() {"] + statements + ["}"]
# image library!
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8')]
@@ -69,6 +63,15 @@ if __name__ == "__main__":
lbls = ['"'+lbls[i]+'"' for i in range(1000)]
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
# buffers (empty + weights)
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]
# the functions
cprog += list(functions.values())
# the net
cprog += ["void net() {"] + statements + ["}"]
cprog += ["""
int main(int argc, char* argv[]) {
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
@@ -101,5 +104,5 @@ int main(int argc, char* argv[]) {
else printf("%s\\n", lbls[best_idx]);
}"""]
# CLANG=1 GPU=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && time ./recognize docs/stable_diffusion_by_tinygrad.jpg
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/stable_diffusion_by_tinygrad.jpg
print('\n'.join(cprog))


@@ -4,7 +4,7 @@ import time
import sys
np.set_printoptions(linewidth=160)
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
from tinygrad.llops.ops_llvm import LLVM, LLVMBuffer, int_const
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
from llvmlite import ir # type: ignore


@@ -1,5 +1,5 @@
import numpy as np
from tinygrad.runtime.metal import CLBuffer, CLProgram
from tinygrad.runtime.ops_metal import CLBuffer, CLProgram
def benchmark(prog):
e = prog()


@@ -3,7 +3,7 @@ import gc
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.llops.ops_gpu import GPUBuffer
from tinygrad.runtime.ops_gpu import GPUBuffer
from tinygrad.ops import GlobalCounters
def print_objects():


@@ -6,7 +6,7 @@ from enum import Enum
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from tinygrad.helpers import getenv, DEBUG
from extra.lib_test_ast import test_ast


@@ -1,8 +1,8 @@
import sys
import numpy as np
from typing import Dict, Type
from tinygrad.ast import ASTKernel
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.compiler.ast import ASTKernel
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import DeviceBuffer, map_buffers
in_test = False


@@ -4,12 +4,11 @@ import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.opencl import CL
from tinygrad.llops.ops_gpu import CLProgram, CLImage, CLBuffer
from tinygrad.runtime.ops_gpu import CLProgram, CLImage, CLBuffer
from tinygrad.helpers import prod, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO
DEBUGCL = getenv("DEBUGCL", 0)
FLOAT16 = getenv("FLOAT16", 0)
@@ -75,29 +74,29 @@ class Thneed:
if o['arg_type'] == "image2d_t":
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
# hack: use a image1d since we can back that with a buffer
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
# buffer isn't supported in image2d, copy buffer into image
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
cl.enqueue_copy(q, arr, bufs[o['buffer_id']])
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
elif o['needs_load']:
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
else:
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
if o['arg_type'] == "image1d_t":
assert not o['needs_load']
assert not bufs_loaded[o['buffer_id']]
buf = cl.Image(CL().cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
if 'data' in o:
buf = cl.Buffer(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
else:
# zero out buffers
buf = cl.Buffer(CL().cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
buf = cl.Buffer(CL.cl_ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
bufs[o['id']] = buf
bufs_loaded[o['id']] = 'data' in o
@@ -119,7 +118,7 @@ class Thneed:
# load binaries
for o in jdat['binaries']:
nptr = ptr + o['length']
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], rename=False, binary=True)
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True)
ptr = nptr
# populate the cl_cache
@@ -194,17 +193,17 @@ class Thneed:
})
if needs_load:
data = np.empty(a.size//4, dtype=np.float32)
cl.enqueue_copy(CL().cl_queue, data, a, is_blocking=True)
cl.enqueue_copy(CL.cl_queue, data, a, is_blocking=True)
weights.append(data.tobytes())
elif isinstance(a, cl.Image):
needs_load = a in self.buffers_to_save
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
size = row_pitch * a.shape[1]
# this is *2 if float16 and *4 if float32
buf = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
buf = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
# zero out the buffer
cl.enqueue_copy(CL().cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
cl.enqueue_copy(CL.cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
CLProgram("from_image_strided", """
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
@@ -224,7 +223,7 @@ class Thneed:
if needs_load:
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
cl.enqueue_copy(CL().cl_queue, data, buf, is_blocking=True)
cl.enqueue_copy(CL.cl_queue, data, buf, is_blocking=True)
if FLOAT16: data = data.astype(np.float16)
weights.append(data.tobytes())
else:
@@ -271,9 +270,9 @@ class Thneed:
events = []
st = time.monotonic()
for prg, args in self.cl_cache:
events.append(prg.clprg(CL().cl_queue, *args))
events.append(prg.clprg(CL.cl_queue, *args))
mt = time.monotonic()
CL().cl_queue.finish()
CL.cl_queue.finish()
et = time.monotonic() - st
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
@@ -284,7 +283,7 @@ class Thneed:
total_runtime = 0
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(prg.op_estimate)/runtime:9.2f} GFLOPS {prg.options} -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
if (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3:
print(prg.prg)
total_runtime += runtime
@@ -321,11 +320,11 @@ class Thneed:
# 3 runs just in case
for i in range(3):
try:
e = prg.clprg(CL().cl_queue, *args)
e = prg.clprg(CL.cl_queue, *args)
except (cl.LogicError, cl.RuntimeError):
# INVALID_WORK_GROUP_SIZE
continue
CL().cl_queue.finish()
CL.cl_queue.finish()
runtime = e.profile.end - e.profile.start
#print(runtime, args[0], args[1])
runtimes.append((runtime, local_args))
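
The recurring CL() -> CL change above reflects the runtime now exposing its OpenCL context and queue as class attributes instead of going through a lazily constructed singleton instance. A minimal sketch of the new shape of that holder (illustrative only; the real class lives in tinygrad/runtime/ops_gpu.py and also handles device selection):

import pyopencl as cl

class CL:
  # class-level singletons: created once, shared by every caller, no CL() instantiation needed
  cl_ctx = cl.create_some_context(interactive=False)
  cl_queue = cl.CommandQueue(cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

# callers now write CL.cl_queue.finish() instead of CL().cl_queue.finish()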


@@ -19,7 +19,8 @@ import numpy as np
import tinygrad.graph as graph
from tinygrad.ops import GlobalCounters
from tinygrad.runtime.opencl import CL
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL
from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor
@@ -64,13 +65,16 @@ def compile(dat, output_fn):
cl_cache = []
for prg,args in model_exec.jit_cache:
real_clprg = prg.clprg
setattr(real_clprg, "op_estimate", prg.op_estimate)
used_ops += real_clprg.op_estimate
# replace clprg with a fake program to log to cl_cache
prg.clprg = lambda *args: cl_cache.append((real_clprg, args))
prg.clprg = lambda *args, wait=False: cl_cache.append((real_clprg, list(args[0:2])+[x._cl for x in args[2:]]))
prg(*args)
# put it back
prg.clprg = real_clprg
from extra.thneed import Thneed
t = Thneed(cl_cache, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
t = Thneed(cl_cache, {k:inputs[k].lazydata.realized.raw()._cl for k in inputs.keys()})
if getenv("OPTWG", 0):
t.optimize_local_workgroup()
@@ -84,7 +88,7 @@ def compile(dat, output_fn):
# confirm thneed found the right output
thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue, thneed_out, t.outputs[0], is_blocking=True)
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
# testing is float32 only (fix this)
@@ -102,11 +106,11 @@ def compile(dat, output_fn):
# try old thneed with a different input
for k,v in t.inputs.items():
CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
t.run()
old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
CL.enqueue_copy(old_thneed_out, t.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue, old_thneed_out, t.outputs[0], is_blocking=True)
# compare thneed (rerun) with torch
np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
@@ -119,11 +123,11 @@ def compile(dat, output_fn):
# inputs
for k,v in nt.inputs.items():
CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
cl.enqueue_copy(CL.cl_queue, v, new_np_inputs[k], is_blocking=True)
nt.run()
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
CL.enqueue_copy(new_thneed_out, nt.outputs[0], is_blocking=True)
cl.enqueue_copy(CL.cl_queue, new_thneed_out, nt.outputs[0], is_blocking=True)
# compare torch to thneed
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)


@@ -14,7 +14,7 @@ setup(name='tinygrad',
license='MIT',
long_description=long_description,
long_description_content_type='text/markdown',
packages = ['tinygrad', 'tinygrad.llops', 'tinygrad.nn', 'tinygrad.runtime', 'tinygrad.shape'],
packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.runtime', 'tinygrad.shape'],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"

sz.sh

@@ -1,2 +1,3 @@
#!/bin/bash
scc tinygrad --by-file
# switched to cloc due to https://github.com/boyter/scc/issues/379
cloc --by-file tinygrad/*


@@ -3,7 +3,7 @@ import unittest
import numpy as np
from tinygrad.ops import LazyOp, ReduceOps, BinaryOps, UnaryOps, MovementOps
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.runtime.ops_gpu import GPUBuffer, CLASTKernel
from tinygrad.helpers import getenv
from extra.lib_test_ast import test_ast


@@ -7,7 +7,7 @@ if 'IMAGE' not in os.environ:
os.environ['GPU'] = '1'
os.environ['OPT'] = '2'
from tinygrad.tensor import Tensor
from tinygrad.llops.ops_gpu import CLImage
from tinygrad.runtime.ops_gpu import CLImage
from tinygrad.nn import Conv2d
Tensor.no_grad = True


@@ -14,8 +14,8 @@ class TestConvShapetracker(unittest.TestCase):
conv(inp).realize()
test = GlobalCounters.cache
GlobalCounters.cache = None
assert len(test) == 1, f"conv should only have one kernel {[x[0].clprg.name for x in test]}"
print(test[0][0].clprg.prg)
assert len(test) == 1, f"conv should only have one kernel {[x[0].name for x in test]}"
print(test[0][0].prg)
for arg in test[0][1]:
print(arg.st)
assert len(arg.st.views) == 1


@@ -3,7 +3,7 @@ import unittest
import networkx as nx # type: ignore
import numpy as np
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.runtime.ops_cpu import CPUBuffer
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps
class TestGraph(unittest.TestCase):


@@ -12,15 +12,6 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, DEBUG
from tinygrad.jit import TinyJit
METAL = getenv("METAL")
try:
from tinygrad.runtime.opencl import CL
if METAL:
from tinygrad.runtime.metal import sync
else:
def sync(): CL().cl_queue.finish()
except ImportError:
def sync(): pass
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
@@ -43,22 +34,17 @@ def helper_test_speed(f1, *args):
ret = None
for _ in range(CNT):
del ret
GlobalCounters.global_ops = 0
GlobalCounters.global_mem = 0
args = [(x+1).realize() if isinstance(x,Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
# sync all before!
sync()
torch.zeros(1, device=torch_device).cpu()
# force syncing
[x.cpu().numpy() for x in args if x is not None]
GlobalCounters.global_ops = 0
GlobalCounters.global_mem = 0
if DEBUG >= 4: print("benchmark start")
st = time.monotonic()
ret = f1(*args)
if isinstance(ret, Tensor) and ret.device in ["GPU"]:
sync()
if not isinstance(ret, Tensor) and torch_device != "cpu":
# TODO: better way to sync?
torch.zeros(1, device=torch_device).cpu()
ret.cpu().numpy() # not ideal, it's copying. why is this so slow in tinygrad?
et = (time.monotonic() - st) * 1000
ets.append(et)
if DEBUG >= 4: print("benchmark stop")


@@ -1,10 +1,26 @@
from enum import Enum, auto
import itertools
from typing import List, Tuple, Optional
from tinygrad.helpers import prod, dedup, all_same
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
from typing import List, Tuple, Optional, Set
from tinygrad.helpers import prod, dedup, all_same, DEBUG
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops, GlobalCounters
from tinygrad.shape import ShapeTracker, View, strides_for_shape
class ASTRunner:
def __init__(self, name, prg, bufs_to_delete:Set[int]=set(), global_work_size:Optional[List[int]]=None, local_work_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_work_size, self.local_work_size, self.bufs_to_delete, self.op_estimate, self.mem_estimate = name, prg, global_work_size, local_work_size, bufs_to_delete, op_estimate, mem_estimate
def build(self, runtime):
self.clprg = runtime(self.name, self.prg)
return self
def __call__(self, *bufs):
et = self.clprg(self.global_work_size, self.local_work_size, *[x.raw() for i,x in enumerate(bufs) if i not in self.bufs_to_delete], wait=DEBUG>=2)
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 1:
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(bufs)-len(self.bufs_to_delete):5d} kernels {str(self.global_work_size):18s} {str(self.local_work_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if DEBUG <= 1 else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return et
def get_first_reduce(shapes):
for i in range(len(shapes[0])):
if not all_same([x[i] for x in shapes]):
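
ASTRunner is the new glue between codegen and the runtimes: a codegen backend returns one carrying the kernel source and launch metadata, exec_ast builds it against the backend's program class, and __call__ handles buffer unwrapping, DEBUG logging and GlobalCounters accounting. A rough usage sketch, with a made-up DummyProgram standing in for a real runtime's (name, prg) program class and an illustrative kernel name:

from tinygrad.codegen.ast import ASTRunner

class DummyProgram:
  # stand-in for CLProgram and friends: a real runtime would compile prg here
  def __init__(self, name, prg): self.name, self.prg = name, prg
  def __call__(self, global_size, local_size, *bufs, wait=False):
    print(f"launch {self.name} global={global_size} local={local_size} args={len(bufs)}")
    return 1e-6 if wait else None   # runtimes may return an execution time in seconds

runner = ASTRunner("ew_S16", "/* generated kernel source */", global_work_size=[16], op_estimate=16)
runner.build(DummyProgram)   # attaches runner.clprg
# runner(*bufs) would then drop bufs_to_delete, pass each buf.raw(), and log to GlobalCounters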


@@ -1,42 +1,35 @@
from __future__ import annotations
import numpy as np
import math
from typing import List, Tuple, Optional, Dict, Union, Set, Final, Callable
from tinygrad.helpers import prod, DEBUG, IMAGE
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.shape import ShapeTracker
from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op
from tinygrad.codegen.ast import ASTKernel, ASTRunner, Token, Types
from tinygrad.shape.symbolic import Node, ModNode, DivNode, render_python
from tinygrad.shape import ShapeTracker
from tinygrad.helpers import getenv, DEBUG, prod
# div is different in cl than python
render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
from tinygrad.helpers import getenv
# TODO: select runtimes in a smarter way
CUDA,METAL,CLANG = getenv("CUDA", 0), getenv("METAL", 0), getenv("CLANG", 0)
if not CUDA and not METAL and not CLANG:
from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
else:
class CLImage: # type: ignore
def __init__(self, shape): raise NotImplementedError("current runtime doesn't support images")
if CUDA: from tinygrad.runtime.cuda import CLBuffer, CLProgram # type: ignore
elif METAL: from tinygrad.runtime.metal import CLBuffer, CLProgram # type: ignore
elif CLANG: from tinygrad.runtime.clang import CLBuffer, CLProgram # type: ignore
VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
KOPT = getenv("KOPT", -1)
PRINT_AST = getenv("PRINT_AST", "0")
TEST_AST = getenv("TEST_AST", 0)
class GPULanguage(NamedTuple):
kernel_prefix : str = ""
buffer_prefix : str = ""
smem_prefix : str = ""
barrier : str = ""
gid : List[str] = []
lid : List[str] = []
extra_args : List[str] = []
float4 : Optional[str] = None
class GPURunner:
def __init__(self, clprg:CLProgram, bufs_to_delete:Set[int], global_work_size:List[int], local_work_size:Optional[List[int]]):
self.clprg, self.global_work_size, self.local_work_size, self.bufs_to_delete = clprg, global_work_size, local_work_size, bufs_to_delete
def __call__(self, *bufs):
return self.clprg(self.global_work_size, self.local_work_size, *[x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete])
class GPUCodegen(ASTKernel):
lang : GPULanguage = GPULanguage()
# for renaming
kernel_cnt : Final[Dict[str, int]] = defaultdict(lambda: -1)
class CLASTKernel(ASTKernel):
code_for_op : Final[Dict[Op, str]] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "((float)1.0-A)",
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
@@ -58,7 +51,7 @@ class CLASTKernel(ASTKernel):
assert len(self.bufs[buf_index].st.views) == 1, "store has more than one view"
# all stores can merge, since they have one view and are valid
should_upcast = not CLANG and self.buftokens[buf_index].can_float4()
should_upcast = self.lang.float4 and self.buftokens[buf_index].can_float4()
to_store = {o:v for o,v in zip(self.buftokens[buf_index].offsets(), value)}
did_store = set()
@@ -68,14 +61,14 @@ class CLASTKernel(ASTKernel):
assert valid.min == 1, "store must always be valid"
if should_upcast:
for j in range(4): did_store.add(o+j)
v = Token(f"{CLProgram.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
v = Token(f"{self.lang.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
idxy, valid = self.sts[buf_index].expr_idxs(o)
assert valid.min == 1, "store must always be valid"
if isinstance(self.bufs[buf_index]._buf, CLImage):
if hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
elif v.typ == Types.FLOAT4:
self.kernel.append(f"(({CLProgram.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
self.kernel.append(f"(({self.lang.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
else:
self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).render(render_cl)}] = {v.tok};\n")
@@ -87,7 +80,7 @@ class CLASTKernel(ASTKernel):
val = self.bufs[buf_index]._backing[0]
assert not math.isnan(val)
const = Token(f"({val}f)", Types.FLOAT)
should_upcast = not CLANG and const is None and self.buftokens[buf_index].can_float4()
should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4()
tokens = []
for o in self.buftokens[buf_index].offsets():
key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
@@ -102,14 +95,14 @@ class CLASTKernel(ASTKernel):
#print((idxy+j).render(), idxy_test.render(), valid.render(), valid_test.render(), can_merge)
if const is not None:
ldr = const
elif isinstance(self.bufs[buf_index]._buf, CLImage):
elif hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
elif should_upcast and can_merge:
ldr = Token(f"(({CLProgram.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
ldr = Token(f"(({self.lang.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
else:
ldr = Token(f"{self.buftokens[buf_index].tok}[{idxy.render(render_cl)}]", Types.FLOAT)
ldr = ldr if valid.min == 1 or (VALIDHACKS and isinstance(self.bufs[buf_index]._buf, CLImage)) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : 0.0f)", ldr.typ) if valid.max == 1 else Token("0.0f", ldr.typ))
ldr = ldr if valid.min == 1 or (VALIDHACKS and hasattr(self.bufs[buf_index]._buf, "IMAGE")) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : 0.0f)", ldr.typ) if valid.max == 1 else Token("0.0f", ldr.typ))
if const is not None:
self.loaded_keys[(buf_index,o)] = ldr
else:
@@ -122,11 +115,11 @@ class CLASTKernel(ASTKernel):
tokens.append(self.loaded_keys[(buf_index,o)])
return tokens
def ast_parse(self, x:Union[GPUBuffer, LazyOp], acc:List[Token], do_reduce=False) -> List[Token]:
def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]:
if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
if isinstance(x.op, ReduceOps) and not do_reduce: return acc
values : List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
code = CLASTKernel.code_for_op[x.op] # TODO: replace this with a function
code = GPUCodegen.code_for_op[x.op] # TODO: replace this with a function
if len(values) == 2:
assert len(values[0]) == len(values[1]) and values[0][0].typ == values[1][0].typ, f"values mismatch {values}"
return [Token(code.replace("A", a.tok).replace("B", b.tok), a.typ) for a,b in zip(values[0], values[1])]
@@ -136,10 +129,10 @@ class CLASTKernel(ASTKernel):
def hand_coded_optimizations(self):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
# shove the axis to the end and remove
if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
if any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs):
eb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
if hasattr(self.bufs[i]._buf, "IMAGE") and self.bufs[i] in self.earlybufs:
valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
eb_valids = [x and y for x,y in zip(eb_valids, valids)]
assert any(eb_valids), f"invalid op with images {eb_valids}"
@@ -158,18 +151,18 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()
# are we grouping?
if not CLANG and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
if self.lang.float4 and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in ((([1024] if METAL else []) + [256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
self.group_for_reduce.append(sz)
break
# if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and not self.buftokens[0].can_float4():
if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf not in self.earlybufs) and not self.buftokens[0].can_float4():
lb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not hasattr(self.bufs[i]._buf, "IMAGE") or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
lb_valids = [x and y for x,y in zip(lb_valids, valids)]
assert any(lb_valids), f"invalid op with images {lb_valids}"
lb_valid = lb_valids.index(True)
@@ -192,7 +185,7 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()
# split to 4 float4s
if self.buftokens[0].can_float4() and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
if self.buftokens[0].can_float4() and any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
xb_choices = []
for i in range(self.first_reduce):
if all(st.shape[i]%4 == 0 for st in self.sts):
@@ -214,7 +207,7 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()
# use more opencl indexing
if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
if self.first_reduce == 2 and hasattr(self.bufs[0]._buf, "IMAGE"):
base_shape = self.bufs[0]._base_shape
if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
if DEBUG >= 3: print("split opencl", base_shape, self.sts[0].shape)
@@ -237,11 +230,11 @@ class CLASTKernel(ASTKernel):
# STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
# group_for_reduce will have to be better first
def codegen(self) -> Callable:
def codegen(self) -> ASTRunner:
self.process()
self.upcast_in_mid_reduce = False
if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)
if KOPT == -1 or IMAGE == 2: self.hand_coded_optimizations()
self.hand_coded_optimizations()
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@@ -257,18 +250,17 @@ class CLASTKernel(ASTKernel):
self.loaded_keys : Dict[Tuple[int,int], Token] = {}
self.prekernel : Set[str] = set()
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(isinstance(buf._buf, CLImage) for buf in self.bufs) else []
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs) else []
# output_shape[-1] is get_global_id(0)
if CLANG:
if len(self.lang.gid) == 0:
self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]
else:
MAX_OUTPUT_SHAPE = 3
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
if len(self.output_shape) > MAX_OUTPUT_SHAPE:
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {self.lang.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(len(self.lang.gid), len(self.output_shape))) if self.output_shape[-1-i] != 1]
if len(self.output_shape) > len(self.lang.gid):
# sometimes, there's more dimensions. compact all the dimensions into the first one
# TODO: these compactions should be searchable
final_dimension = len(self.output_shape)-MAX_OUTPUT_SHAPE
final_dimension = len(self.output_shape)-len(self.lang.gid)
for i in range(final_dimension-1, -1, -1):
self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
@@ -282,7 +274,7 @@ class CLASTKernel(ASTKernel):
acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets()
assert self.reduceopop is not None
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceopop]};\n" for accumulator in accumulators]
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {GPUCodegen.start_for_op[self.reduceopop]};\n" for accumulator in accumulators]
self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [accumulators[off] for off in acc_offsets], do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
@@ -292,9 +284,9 @@ class CLASTKernel(ASTKernel):
assert lvalid.min == 1, "local buffer must always be valid"
self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n")
for i,acc in enumerate(accumulators):
self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
self.kernel.append(self.lang.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
self.kernel.append(CLProgram.barrier+"\n")
self.kernel.append(self.lang.barrier+"\n")
if self.upcast_in_mid_reduce:
assert len(self.group_for_reduce) == 2
@@ -311,86 +303,31 @@ class CLASTKernel(ASTKernel):
if self.upcast_in_mid_reduce:
self.kernel.append(f'float4 ld = vload4(0, &temp{i}[mid*4]);\n')
for j in range(4):
self.kernel.append(CLASTKernel.code_for_op[self.reduceopop].replace('A', new_accumulators[i*4+j].tok).replace('B', f'ld.{"xyzw"[j]}')+";\n")
self.kernel.append(GPUCodegen.code_for_op[self.reduceopop].replace('A', new_accumulators[i*4+j].tok).replace('B', f'ld.{"xyzw"[j]}')+";\n")
else:
self.kernel.append(CLASTKernel.code_for_op[self.reduceopop].replace('A', new_accumulators[i].tok).replace('B', f'temp{i}[mid]')+";\n")
self.kernel.append(GPUCodegen.code_for_op[self.reduceopop].replace('A', new_accumulators[i].tok).replace('B', f'temp{i}[mid]')+";\n")
self.kernel.append("}\n")
accumulators = new_accumulators
# late ast
self.store(0, self.ast_parse(self.ast, accumulators))
if self.group_for_reduce: self.kernel.append("}")
if CLANG: self.kernel += ["}"] * len(self.output_shape)
if len(self.lang.gid) == 0: self.kernel += ["}"] * len(self.output_shape)
self.kernel.append("\n}")
# kernel function definition
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else (CLProgram.buffer_prefix+self.buftokens[i].decltype() + ("restrict" if CLANG else "")) for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"{CLProgram.kernel_prefix} void {function_name}(",] + \
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + CLProgram.extra_args)] + \
GPUCodegen.kernel_cnt[function_name] += 1
if GPUCodegen.kernel_cnt[function_name]:
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"{self.lang.kernel_prefix} void {function_name}(",] + \
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + \
[") {\n"] + self.kernel
# compile kernel
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
return GPURunner(self.fxn, self.bufs_to_delete, self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None)
def print(self):
super().print()
for i in range(len(self.bufs)):
print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.sts[i])
print(self.fxn.prg)
class GPUBuffer(ExplicitExecAST):
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
super().__init__(shape, hostbuf)
self._buf : Optional[Union[CLImage, CLBuffer]] = hostbuf._buf if hostbuf is not None else None
self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
# early copy in for large buffers
if (self._backing is not None and self._backing.shape != (1,)) or force_create:
self.cl
# TODO: refactor this to return self._buf and not import pyopencl
@property
def cl(self) -> Union[CLBuffer, CLImage]:
if self._buf is None:
self._buf = CLImage(self._base_shape) if (len(self._base_shape) == 3 and self._base_shape[2] == 4 and IMAGE >= 2) else CLBuffer(4*prod(self._base_shape))
assert self._buf is not None
if self._backing is not None:
assert GlobalCounters.cache is None, f"can't copy in {self._backing.shape} while caching"
self._buf.copyin(self._backing)
self._backing = None
return self._buf._cl
# TODO: we don't always need a hostbuf
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
@staticmethod
def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
def toCPU(self) -> np.ndarray:
cl_buf = self.contiguous()
cl_buf.cl # force buffer creation, happens if it's a backed buffer that hasn't been created yet
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self.movement_op(MovementOps.RESHAPE, tuple(list(self.shape)+[1])), )))
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
data = np.empty(self.shape, dtype=np.float32)
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
cl_buf._buf.copyout(data)
return data
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None):
k = CLASTKernel(ast, output_buffer)
if KOPT > 0:
from extra.kernel_search import apply_optimization
apply_optimization(k, ast, max_interventions=KOPT)
prg = k.codegen()
if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
prg(*k.bufs)
if PRINT_AST == "1" or (hasattr(k, "fxn") and PRINT_AST == k.fxn.name):
print(k.fxn.name)
k.print()
if TEST_AST:
from extra.lib_test_ast import test_ast # type: ignore
test_ast(k)
return k.ret
return ASTRunner(function_name, ' '.join(self.kernel), self.bufs_to_delete,
self.output_shape[::-1] if len(self.output_shape) > 0 else [1],
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
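
GPUCodegen is now parameterized by a GPULanguage NamedTuple rather than reaching into CLProgram class attributes, so OpenCL, Metal, CUDA and clang can share one code generator and differ only in syntax. Illustrative configurations (the concrete values ship with each runtime's ops_* module and may differ in detail):

# assumes: from tinygrad.codegen.gpu import GPULanguage, GPUCodegen (module path assumed)
opencl_lang = GPULanguage(
  kernel_prefix="__kernel", buffer_prefix="__global ", smem_prefix="__local ",
  barrier="barrier(CLK_LOCAL_MEM_FENCE);",
  gid=[f"get_global_id({i})" for i in range(3)], lid=[f"get_local_id({i})" for i in range(3)],
  float4="(float4)")

clang_lang = GPULanguage()   # all defaults: plain C for-loops, no gid/lid, no float4 upcasting

class OpenCLCodegen(GPUCodegen):   # hypothetical pairing; the real runtimes wire a lang and a program class together
  lang = opencl_lang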


@@ -1,19 +1,12 @@
from __future__ import annotations
import math
import functools
from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape import ShapeTracker
from tinygrad.ops import LazyOp
from tinygrad.ast import ASTKernel
import ctypes
import numpy as np
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST
from tinygrad.runtime.llvm import LLVM, ir
import functools, math
from typing import ClassVar, List
from llvmlite import ir # type: ignore
from tinygrad.codegen.ast import ASTKernel, ASTRunner
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp
from tinygrad.helpers import DEBUG, prod
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode
def int_const(x): return ir.Constant(ir.IntType(64), x)
render_llvm = {
Variable: lambda self,ops,ctx: self.expr,
NumNode: lambda self,ops,ctx: int_const(self.b),
@@ -26,7 +19,7 @@ render_llvm = {
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
}
class LLVMBuffer(ExplicitExecAST):
class LLVMCodegen(ASTKernel):
op_lookup : ClassVar = {
UnaryOps.NOOP: lambda builder,x: x,
UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
@@ -46,41 +39,10 @@ class LLVMBuffer(ExplicitExecAST):
ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
}
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None, force_create=False):
super().__init__(shape, hostbuf)
# TODO: force alignment?
self._buf = (ctypes.c_float * (prod(self.shape)))() if hostbuf is None else hostbuf._buf
#assert ctypes.addressof(self._buf) & 0x1F == 0
def codegen(self):
self.process()
if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)
def __repr__(self): return f"LLVMBuffer {str(self.st)}"
@staticmethod
def fromCPU(x):
x = x.astype(np.float32)
ret = LLVMBuffer(x.shape)
ctypes.memmove(ret._buf, x.ctypes.data, prod(ret.shape)*4)
return ret
def toCPU(x): return np.ctypeslib.as_array(x.contiguous()._buf)[:prod(x.shape)].reshape(x.shape).copy()
func_cache : Dict[str, Any] = {}
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[LLVMBuffer]=None) -> LLVMBuffer:
k = ASTKernel(ast, output_buffer)
# cached kernel
if k.key in LLVMBuffer.func_cache:
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
return k.ret
# process if uncached
k.process()
if DEBUG >= 2:
print(k.ast)
print("old:", [x.shape for x in k.sts])
print("old:", [x.views[-1].strides for x in k.sts])
# this stuff can't be hand coded
kernel_output_axis : List[int] = []
"""
@@ -119,12 +81,12 @@ class LLVMBuffer(ExplicitExecAST):
"""
# the 4x4 need to go all the way at the end, even after reduce
output_shape = k.sts[0].shape
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
output_shape = self.sts[0].shape
full_shape_options = [x.shape for x in self.sts if x.shape != output_shape]
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
kernel_output_dim = prod([self.sts[0].shape[a] for a in kernel_output_axis])
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
def get_idxs(builder, idx, buf_index):
@@ -143,7 +105,7 @@ class LLVMBuffer(ExplicitExecAST):
# create llvm function
module = ir.Module(name=__file__)
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(k.bufs))), name='exec')
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.FloatType().as_pointer()]*(len(self.bufs))), name='exec')
# force llvmlite to allow us to add function attribute then add the attribute
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
@@ -158,13 +120,13 @@ class LLVMBuffer(ExplicitExecAST):
loop_exit = loop_exit[::-1]
# add the buffer indexing
idx_level = [[int_const(st.offset)] for st in k.sts]
idx_level = [[int_const(st.offset)] for st in self.sts]
for i in range(len(full_shape)):
for j in range(len(k.bufs)):
for j in range(len(self.bufs)):
# stride
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
si_ps = loop_exit[i+1].add(si, int_const(self.sts[j].views[-1].strides[i]))
si.add_incoming(si_ps, loop_exit[i+1]._block)
idx_level[j].append(si)
@@ -172,7 +134,7 @@ class LLVMBuffer(ExplicitExecAST):
def ast_parse(builder, x, level, reduce_result=None):
if not isinstance(x, LazyOp):
m = kernel_output_type(ir.Undefined)
buf_index = k.bufs.index(x)
buf_index = self.bufs.index(x)
for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)):
# first view is already implictly handled
idx, valid = x.st._expr_idx(Variable(idx, 0, prod(x.st.shape)))
@@ -195,12 +157,12 @@ class LLVMBuffer(ExplicitExecAST):
m = kernel_output_type(ir.Undefined)
if kernel_output_dim == 1:
return LLVMBuffer.op_lookup[x.op](builder, *values)
return LLVMCodegen.op_lookup[x.op](builder, *values)
else:
# TODO: this only has to be done for certain ops
for i in range(kernel_output_dim):
value = [builder.extract_element(v, int_const(i)) for v in values]
element = LLVMBuffer.op_lookup[x.op](builder, *value)
element = LLVMCodegen.op_lookup[x.op](builder, *value)
m = builder.insert_element(m, element, int_const(i))
return m
@@ -209,9 +171,9 @@ class LLVMBuffer(ExplicitExecAST):
# do the early ast
reduce_result = None
if k.reduceop:
reduce_input = ast_parse(loop_exit[-1], k.reduceop.src[0], -1)
phis = [LLVMBuffer.start_for_op[k.reduceop.op]] # type: ignore
if self.reduceop:
reduce_input = ast_parse(loop_exit[-1], self.reduceop.src[0], -1)
phis = [LLVMCodegen.start_for_op[self.reduceop.op]] # type: ignore
if kernel_output_dim > 1:
phis = [kernel_output_type(phis * kernel_output_dim)]
for i in range(store_loop+1, len(loop_entry)):
@@ -219,16 +181,16 @@ class LLVMBuffer(ExplicitExecAST):
val.add_incoming(phis[-1], loop_entry[i-1]._block)
phis.append(val)
if k.reduceop.op == ReduceOps.SUM:
if self.reduceop.op == ReduceOps.SUM:
reduce_result = loop_exit[-1].fadd(reduce_input, val, flags=('fast',))
elif k.reduceop.op == ReduceOps.MAX:
elif self.reduceop.op == ReduceOps.MAX:
reduce_result = loop_exit[-1].select(loop_exit[-1].fcmp_unordered(">", val, reduce_input, flags=('fast',)), val, reduce_input, flags=('fast',))
for i,phi in enumerate(phis[1:]):
phi.add_incoming(reduce_result, loop_exit[store_loop+1+i]._block)
# do the late ast
result = ast_parse(loop_exit[store_loop], k.ast, store_loop, reduce_result=reduce_result)
result = ast_parse(loop_exit[store_loop], self.ast, store_loop, reduce_result=reduce_result)
# store result
builder = loop_exit[store_loop]
@@ -247,5 +209,5 @@ class LLVMBuffer(ExplicitExecAST):
loop_entry[-1].branch(loop_exit[-1]._block)
loop_exit[0].ret_void()
LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
return k.ret
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
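
LLVMCodegen follows the same contract: codegen() returns an ASTRunner whose prg is the textual IR from str(module), and the matching runtime in tinygrad/runtime/ops_llvm.py is what JIT-compiles and invokes it. A minimal sketch of such a wrapper using llvmlite's MCJIT binding, assuming the buffers arrive as ctypes float arrays (class name and details are illustrative, not the real LLVM runtime):

import ctypes
import llvmlite.binding as llvm

llvm.initialize(); llvm.initialize_native_target(); llvm.initialize_native_asmprinter()

class LLVMProgramSketch:
  def __init__(self, name, prg):
    target_machine = llvm.Target.from_default_triple().create_target_machine()
    self.engine = llvm.create_mcjit_compiler(llvm.parse_assembly(prg), target_machine)
    self.engine.finalize_object()
    self.addr = self.engine.get_function_address(name)   # 'exec' per the ASTRunner above
  def __call__(self, global_size, local_size, *bufs, wait=False):
    # the generated function takes one float* per buffer; work sizes are meaningless on CPU
    ctypes.CFUNCTYPE(None, *([ctypes.c_void_p] * len(bufs)))(self.addr)(*[ctypes.addressof(b) for b in bufs])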


@@ -13,7 +13,7 @@ class TinyJit:
self.input_replace : Dict[DeviceBuffer, Any]= {}
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
if Device.DEFAULT not in ["GPU", "CLANG"]: return self.fxn(*args, **kwargs) # only jit on the GPU
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_tensors = {k:cast(DeviceBuffer, v.realize().lazydata.realized)._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_tensors) != 0, "no inputs to JIT"
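
The JIT gate now checks Device.DEFAULT against a small allow-list (GPU and CLANG) instead of hard-coding GPU. Usage is unchanged; a sketch:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def mul(a: Tensor, b: Tensor) -> Tensor:
  return (a * b).realize()

for _ in range(3):   # early calls record the kernels, later calls replay them from jit_cache
  out = mul(Tensor.randn(4, 4), Tensor.randn(4, 4))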


@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import sys, weakref, importlib, inspect
import os, sys, weakref, importlib, inspect
from weakref import WeakValueDictionary
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape import ShapeTracker
@@ -14,21 +14,20 @@ sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZY = getenv("LAZY", 1)
def get_buffer(name, base='tinygrad.llops'):
def get_buffer(name, base='tinygrad.runtime'):
try:
return (name.upper(), [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0])
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.ops_{name}'), inspect.isclass) if (cname.lower() == name + "buffer")][0]
except Exception as e: # NOTE: this can't be put on one line due to mypy issue
print(name, "backend not available", e, file=sys.stderr)
class _Device:
def __init__(self) -> None:
self._buffers : Dict[str, Type[DeviceBuffer]] = {x[0]:x[1] for x in [
get_buffer('cpu'), get_buffer('gpu'), get_buffer('llvm'), get_buffer('torch'),
get_buffer('triton', 'accel.triton')] if x is not None}
self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in
[os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None}
self.DEFAULT : str = "CPU"
for name in self._buffers:
if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
self.__setattr__(name, name)
if self._buffers[name] is not None: self.__setattr__(name, name)
Device = _Device()
# TODO: movement ops that only change shape are really nops. treat them as such
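
Backend discovery is no longer a hand-maintained list: every tinygrad/runtime/ops_<name>.py that defines a <name>buffer class becomes a device, and setting the matching env var (GPU=1, CLANG=1, ...) makes it the default. The scan boils down to roughly this sketch:

import os, importlib, inspect

def discover_backends(runtime_dir: str):
  # every runtime/ops_<name>.py that defines a <name>buffer class is a selectable device
  names = [os.path.splitext(f)[0][len("ops_"):] for f in sorted(os.listdir(runtime_dir)) if f.startswith("ops_")]
  backends = {}
  for name in names:
    try:
      mod = importlib.import_module(f"tinygrad.runtime.ops_{name}")
      backends[name.upper()] = [c for n, c in inspect.getmembers(mod, inspect.isclass) if n.lower() == name + "buffer"][0]
    except Exception as e:   # a backend whose dependencies aren't installed simply isn't offered
      print(name, "backend not available", e)
  return backends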


@@ -1,7 +1,7 @@
from __future__ import annotations
import numpy as np
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar
import functools, operator
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape import ShapeTracker
@@ -32,13 +32,35 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
if x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
_T = TypeVar("_T")
class RawBuffer:
def __init__(self, size): raise NotImplementedError("must be implemented")
@classmethod
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
class RawBufferCopyIn(RawBuffer):
def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@classmethod
def fromCPU(cls, x:np.ndarray):
ret = cls(4*prod(x.shape))
ret.copyin(x)
return ret
class RawBufferCopyInOut(RawBufferCopyIn):
size : int
def copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray:
x = np.empty((self.size//4), dtype=np.float32)
self.copyout(x)
return x
# a placeholder class to extend by the exec classes
class DeviceBuffer:
class DeviceBuffer(RawBuffer):
_buf: Any # underlying buffer
shape: Tuple[int, ...]
@staticmethod
def fromCPU(x:np.ndarray) -> DeviceBuffer: raise NotImplementedError("must be implemented")
def toCPU(self:DeviceBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer=None): raise NotImplementedError("must be implemented")
@@ -53,14 +75,14 @@ shape_fxn_for_op : Dict[Op, Callable] = {
**{op:functools.partial(lambda mop,self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops), op) for op in MovementOps}}
# used in CPUBuffer and TorchBuffer
class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
fxn_for_op : ClassVar = shape_fxn_for_op
# TODO: use generic types here to remove __init__ in specialized classes
def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
def __init__(self, lbuf:Any): self._buf, self.shape = lbuf, tuple(lbuf.shape)
def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None):
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[InterpretedBuffer]=None):
if FusedOps.MULACC in cls.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
srcs = [cls.exec_ast(x) if isinstance(x, LazyOp) else x for x in ast.src]
@@ -68,22 +90,57 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
if ast.op in BinaryOps: assert srcs[0].shape == srcs[1].shape, f"BinaryOps shape mismatch {srcs[0].shape} != {srcs[1].shape}"
if ast.op in ReduceOps: assert all(r == n or n == 1 for r,n in zip(srcs[0].shape, ast.arg)), f"ReduceOps can't reduce {srcs[0].shape} -> {ast.arg}"
if ast.op in MovementOps: ret = srcs[0].movement_op(ast.op, ast.arg)
else: ret = cls(cls.fxn_for_op[ast.op](*([x.buf for x in srcs] + ([ast.arg] if ast.arg else []))))
else: ret = cls(cls.fxn_for_op[ast.op](*([x._buf for x in srcs] + ([ast.arg] if ast.arg else []))))
if output_buffer is not None:
assert output_buffer.shape == ret.shape
output_buffer.buf = ret.buf
output_buffer._buf = ret._buf
return output_buffer
else:
return ret
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(map_buffers({x:GenericExecAST(GenericShape(x.shape)) for x in get_buffers(ast)}, ast)).buf
def get_lazyop_info(ast:LazyOp): return InterpretedBuffer.exec_ast(map_buffers({x:InterpretedBuffer(GenericShape(x.shape)) for x in get_buffers(ast)}, ast))._buf
# assumes you are using ShapeTracker
# used in GPUBuffer and LLVMBuffer
class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[CompiledBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape = self.st.shape
self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
self._buf = hostbuf._buf if hostbuf is not None else None
self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
if (self._backing is not None and self._backing.shape != (1,)) or force_create: self.raw()
# TODO: not GPUBuffer, get name of class
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
raw_buffer_type : Type[RawBuffer]
@classmethod
def create_raw_buffer(cls, shape, backing) -> RawBuffer:
assert backing is None or prod(shape) == prod(backing.shape), "backing has the wrong shape"
assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
def raw(self) -> RawBuffer:
if self._buf is None: self._buf = self.create_raw_buffer(self._base_shape, self._backing)
self._backing = None
return self._buf
@classmethod
def fromCPU(cls, x:np.ndarray) -> CompiledBuffer: return cls(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
def toCPU(self) -> np.ndarray:
assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
return self.contiguous().raw().toCPU().reshape(self.shape)
codegen_type : Any
runtime_type : Type
method_cache : Dict[str, Callable] = {}
@classmethod
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
k = cls.codegen_type(ast, output_buffer)
if k.key not in cls.method_cache: cls.method_cache[k.key] = k.codegen().build(cls.runtime_type)
prg = cls.method_cache[k.key]
if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
prg(*k.bufs)
return k.ret
# universal for shape tracked
def contiguous(self): return self if self.st.contiguous and prod(self._base_shape) == prod(self.shape) else type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
@@ -92,12 +149,12 @@ class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
class GlobalCounters:
global_ops : ClassVar[int] = 0
global_mem : ClassVar[int] = 0
time_sum : ClassVar[int] = 0
time_sum_s : ClassVar[float] = 0.0
kernel_count : ClassVar[int] = 0
mem_used : ClassVar[int] = 0 # NOTE: this is not reset
cache : ClassVar[Optional[list]] = None
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
@staticmethod
def log_kernel(op_estimate:int, mem_estimate:int):
GlobalCounters.kernel_count += 1
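GlobalCounters.cache is the hook the JIT and export tools rely on: when it is set to a list, CompiledBuffer.exec_ast still launches each kernel but also appends the (compiled program, argument buffers) pair so the caller can replay the schedule later. A hedged usage sketch (capture_kernels is an illustrative helper, not a tinygrad API; inputs must already be realized, since copy-ins assert while caching):

from tinygrad.ops import GlobalCounters

def capture_kernels(realize_fn):
  # collect the (prg, bufs) pairs exec_ast appends while the cache is active
  GlobalCounters.cache = []
  try:
    realize_fn()
    return GlobalCounters.cache
  finally:
    GlobalCounters.cache = None

# usage sketch: cache = capture_kernels(lambda: (a + b).realize())
#               for prg, bufs in cache: prg(*bufs)   # replay without re-running codegen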

View File

@@ -1,38 +0,0 @@
import ctypes
import os
import numpy as np
import hashlib
import subprocess
from collections import defaultdict
from typing import List, Final, Dict
from tinygrad.helpers import DEBUG
import platform
OSX = platform.system() == "Darwin"
class CLBuffer:
def __init__(self, size): self._cl = (ctypes.c_float * (size))()
def copyin(self, b:np.ndarray): ctypes.memmove(self._cl, b.ctypes.data, b.size*4)
def copyout(self, a:np.ndarray):
np.copyto(a, np.ctypeslib.as_array(self._cl)[:a.size].reshape(a.shape))
class CLProgram:
kernel_prefix, buffer_prefix, smem_prefix, barrier = "", "", "", ""
gid = [f"gid[{i}]" for i in range(3)]
lid = [f"lid[{i}]" for i in range(3)]
extra_args : List[str] = []
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
# TODO: remove name, factor out op_estimate and mem_estimate
def __init__(self, name:str, prg:str, rename=True, op_estimate=0, mem_estimate=0):
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
CLProgram.kernel_cnt[name] += 1
self.prg = prg.replace(f"{name}(", f"{self.name}(")
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n" + prg
if DEBUG >= 4: print(prg) # TODO: outside runtime!
# TODO: is there a way to not write this to disk?
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if OSX else 'so'}"
if not os.path.exists(fn):
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
os.rename(fn+".tmp", fn)
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
def __call__(self, *args): self.fxn(*args[2:])

View File

@@ -1,37 +0,0 @@
from typing import Optional, List
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import compile # type: ignore
import numpy as np
from tinygrad.helpers import DEBUG
from tinygrad.ops import GlobalCounters
class CLBuffer:
def __init__(self, size): self._cl = cuda.mem_alloc(size)
def copyin(self, b:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, b, stream)
def copyout(self, a:np.ndarray): cuda.memcpy_dtoh(a, self._cl)
class CLProgram:
kernel_prefix = "__global__"
buffer_prefix = ""
smem_prefix = "__shared__ "
barrier = "__syncthreads();"
float4 = "make_float4"
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
extra_args : List[str] = []
def __init__(self, name:str, prg:str, binary=False, shared=0, op_estimate:int=0, mem_estimate:int=0):
self.name, self.op_estimate, self.mem_estimate, self.shared = name, op_estimate, mem_estimate, shared
if DEBUG >= 4 and not binary: print("CUDA compile", prg)
if not binary: prg = compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
if DEBUG >= 5: print(prg)
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
def __call__(self, global_size, local_size, *args):
local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
global_size = global_size + [1] * (3 - len(global_size))
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
global_size = [x//y for x,y in zip(global_size, local_size)]
if DEBUG >= 2: print("CUDA launch", global_size, local_size)
self.prg(*args, block=tuple(local_size), grid=tuple(global_size), shared=self.shared)
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)

View File

@@ -1,98 +0,0 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-libdispatch
import Metal, Cocoa, libdispatch # type: ignore
import numpy as np
from typing import List, Any
from tinygrad.ops import GlobalCounters
from tinygrad.helpers import prod, getenv, DEBUG
import subprocess, pathlib
METAL_XCODE = getenv("METAL_XCODE")
device = Metal.MTLCreateSystemDefaultDevice()
mtl_queue = device.newCommandQueue()
mtl_buffers_in_flight : List[Any] = []
def sync():
global mtl_buffers_in_flight
for cbuf in mtl_buffers_in_flight: cbuf.waitUntilCompleted()
mtl_buffers_in_flight = []
class CLBuffer:
def __init__(self, size): self._cl = device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
def __del__(self): self._cl.release()
def copyin(self, b:np.ndarray):
np.copyto(np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32), b.reshape(-1).data)
def toCPU(self):
sync()
return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
# TODO: remove copyout everywhere
def copyout(self, a:np.ndarray): np.copyto(a, self.toCPU().reshape(a.shape))
class CLProgram:
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel"
buffer_prefix = "device "
smem_prefix = "threadgroup "
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
float4 = "float4"
gid = [f"gid.{chr(120+i)}" for i in range(3)]
lid = [f"lid.{chr(120+i)}" for i in range(3)]
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
def __init__(self, name:str, prg:str, op_estimate:int=0, mem_estimate:int=0):
self.name, self.op_estimate, self.mem_estimate = name, op_estimate, mem_estimate
if DEBUG >= 4: print("Metal compile", prg)
if DEBUG >= 6: # dump llvm
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
print(dis.decode('utf-8'))
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library, err = device.newLibraryWithData_error_(data, None)
else:
options = Metal.MTLCompileOptions.alloc().init()
self.library, err = device.newLibraryWithSource_options_error_(prg, options, None)
assert err is None, str(err)
self.fxn = self.library.newFunctionWithName_(name) #self.library.functionNames()[0]
# hacks to disassemble shader
if DEBUG >= 5:
arc, err = device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
assert err is None, str(err)
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
desc.setComputeFunction_(self.fxn)
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
assert err is None, str(err)
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
assert err is None, str(err)
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
import os
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
self.pipeline_state, err = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
assert err is None, str(err)
def __call__(self, global_size, local_size, *args):
global_size += [1] * (3-len(global_size))
if local_size is None: local_size = [32]
local_size += [1] * (3-len(local_size))
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(args):
encoder.setBuffer_offset_atIndex_(a, 0, i)
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
if DEBUG >= 2:
command_buffer.waitUntilCompleted()
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
print(f"METAL et {et*1e6:8.2f} us {self.name:28s} launch {str(global_size):18s} {local_size}")
GlobalCounters.time_sum += et
else:
mtl_buffers_in_flight.append(command_buffer)
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return command_buffer

View File

@@ -1,101 +0,0 @@
import functools, platform
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, Tuple, List, ClassVar, Final
from collections import defaultdict
from tinygrad.ops import GlobalCounters
from tinygrad.helpers import getenv, DEBUG
OSX = platform.system() == "Darwin"
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
CLCACHE = getenv("CLCACHE", 1)
FLOAT16 = getenv("FLOAT16", 0)
class CL:
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
cl_ctx : ClassVar[Optional[cl.Context]] = None
cl_queue : ClassVar[Optional[cl.CommandQueue]] = None
def __init__(self) -> None:
if CL.cl_queue is not None: return # already initted
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: # settle for CPU
devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], [])
CL.cl_ctx = cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
if len(devices) > 1 or DEBUG >= 1: print(f"using {CL.cl_ctx.devices}")
CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
@staticmethod
def enqueue_copy(a, b, is_blocking=False):
if DEBUG >= 1: print(f"**CL** copy in {b.shape}" if isinstance(b, np.ndarray) else f"**CL** copy OUT {a.shape}")
cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)
class CLBuffer:
def __init__(self, size):
if DEBUG >= 4: print(f"allocate GPU Buffer {size}")
if len(CL.BUFFER_CACHE[size]) > 0:
self._cl = CL.BUFFER_CACHE[size].pop()
else:
# TODO: on GPU OOM, clear the cache
self._cl = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size)
GlobalCounters.mem_used += self._cl.size
def __del__(self):
if CLCACHE: CL.BUFFER_CACHE[self._cl.size].append(self._cl)
else: GlobalCounters.mem_used -= self._cl.size
def copyin(self, b:np.ndarray): CL.enqueue_copy(self._cl, b, False)
def copyout(self, a:np.ndarray): CL.enqueue_copy(a, self._cl, True)
class CLImage:
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
def __init__(self, shape):
self._cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
def __del__(self):
GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
def copyin(self, b:np.ndarray): raise NotImplementedError("no copyin for CLImage")
def copyout(self, a:np.ndarray): raise NotImplementedError("no copyout for CLImage")
@functools.lru_cache(maxsize=None)
class CLProgram:
kernel_prefix = "__kernel"
buffer_prefix = "__global "
smem_prefix = "__local "
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
float4 = "(float4)"
gid = [f'get_global_id({i})' for i in range(3)]
lid = [f'get_local_id({i})' for i in range(3)]
extra_args : List[str] = []
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate
self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg) # type: ignore
try:
self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
except cl.RuntimeError as e:
if DEBUG >= 3: print("FAILED TO BUILD", self.prg)
raise e
if self.argdtypes is not None:
self.clprg.set_scalar_arg_dtypes(self.argdtypes)
CLProgram.kernel_cnt[name] += 1
def __call__(self, *args) -> cl.Event:
if DEBUG >= 4: print(args[0], args[1], self.prg)
# print the PTX for NVIDIA. TODO: probably broken for everything else
if DEBUG >= 5 and not OSX: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8'))
e = self.clprg(CL().cl_queue, *args)
if DEBUG >= 2:
assert CL.cl_queue is not None
CL.cl_queue.finish()
# NOTE: Profiling is not in ns in OS X, we multiply by a computed ratio
et = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
GlobalCounters.time_sum += et
if DEBUG >= 1:
print(f"**CL** {GlobalCounters.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if DEBUG <= 1 else f"tm {et/1e3:9.2f}us/{GlobalCounters.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)"))
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return e

View File

@@ -0,0 +1,37 @@
import ctypes
import os, time
import numpy as np
import hashlib
import subprocess
from collections import defaultdict
from typing import Final, Dict
from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
from tinygrad.codegen.gpu import GPUCodegen
import platform
OSX = platform.system() == "Darwin"
class RawMallocBuffer(RawBufferCopyIn):
def __init__(self, size): self._buf = (ctypes.c_float * (size//4))()
def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
def toCPU(self): return np.ctypeslib.as_array(self._buf)
class ClangProgram:
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
def __init__(self, name:str, prg:str):
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n" + prg
# TODO: is there a way to not write this to disk?
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if OSX else 'so'}"
if not os.path.exists(fn):
subprocess.check_output(['clang', '-shared', '-O2', '-Wall','-Werror', '-lm', '-fPIC', '-x', 'c', '-', '-o', fn+".tmp"], input=prg.encode('utf-8'))
os.rename(fn+".tmp", fn)
self.lib = ctypes.CDLL(fn)
self.fxn = self.lib[name]
def __call__(self, *args, wait=False):
if wait: st = time.monotonic()
self.fxn(*[x._buf for x in args[2:]])
if wait: return time.monotonic()-st
class ClangBuffer(CompiledBuffer):
raw_buffer_type = RawMallocBuffer
codegen_type = GPUCodegen # clang is the default
runtime_type = ClangProgram
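The clang runtime above can also be exercised on its own: ClangProgram compiles a C source string into a shared object and looks up the entry point by name, while RawMallocBuffer provides the float storage. A rough sketch, assuming a working clang on PATH (the add4 kernel is illustrative, not part of tinygrad):

import numpy as np
from tinygrad.runtime.ops_clang import ClangProgram, RawMallocBuffer

src = "void add4(float *out, float *a, float *b) { for (int i = 0; i < 4; i++) out[i] = a[i] + b[i]; }"
prg = ClangProgram("add4", src)

a, b, out = RawMallocBuffer(16), RawMallocBuffer(16), RawMallocBuffer(16)  # 16 bytes = 4 floats each
a.copyin(np.arange(4, dtype=np.float32))
b.copyin(np.full(4, 10, dtype=np.float32))
prg(None, None, out, a, b)  # the first two args stand in for global/local size and are ignored here
print(out.toCPU())          # [10. 11. 12. 13.]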

View File

@@ -1,7 +1,7 @@
import numpy as np
import operator
from typing import ClassVar, Callable, Dict
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, GenericExecAST, Op
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, InterpretedBuffer, Op
from tinygrad.helpers import shape_to_axis
base_fxn_for_op : Dict[Op, Callable] = {
@@ -30,9 +30,9 @@ numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to)
}}
class CPUBuffer(GenericExecAST):
class CPUBuffer(InterpretedBuffer):
fxn_for_op : ClassVar = numpy_fxn_for_op
@staticmethod
def fromCPU(x): return CPUBuffer(x)
def toCPU(x): return x.buf
def toCPU(x): return x._buf

View File

@@ -0,0 +1,38 @@
from typing import Optional
import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import compile # type: ignore
import numpy as np
from tinygrad.helpers import DEBUG
from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
class RawCUDABuffer(RawBufferCopyInOut):
def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size)
def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)
class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False):
if not binary: prg = compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
if DEBUG >= 5: print(prg)
# TODO: name is wrong
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
def __call__(self, global_size, local_size, *args, wait=False):
local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)
global_size = global_size + [1] * (3 - len(global_size))
assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
global_size = [x//y for x,y in zip(global_size, local_size)]
self.prg(*args, block=tuple(local_size), grid=tuple(global_size))
class CUDACodegen(GPUCodegen):
lang = GPULanguage(
kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
class CUDABuffer(CompiledBuffer):
raw_buffer_type = RawCUDABuffer
codegen_type = CUDACodegen
runtime_type = CUDAProgram

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import platform, functools
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, List, ClassVar, Final
from collections import defaultdict
from tinygrad.helpers import IMAGE, DEBUG, getenv
from tinygrad.ops import CompiledBuffer, GlobalCounters, RawBufferCopyInOut, RawBuffer
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
OSX = platform.system() == "Darwin"
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
CLCACHE = getenv("CLCACHE", 1)
FLOAT16 = getenv("FLOAT16", 0)
class _CL:
@functools.cached_property
def cl_ctx(self) -> cl.Context:
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) # settle for CPU
if len(devices) > 1 or DEBUG >= 1: print(f"using {devices[getenv('CL_DEVICE', 0)]}")
return cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
@functools.cached_property
def cl_queue(self) -> cl.CommandQueue:
return cl.CommandQueue(CL.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE) # this is an in-order command queue
CL = _CL()
class CLBuffer(RawBufferCopyInOut):
# TODO: this can be in RawBuffer generically
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
def __init__(self, size):
self.size = size
if len(CLBuffer.BUFFER_CACHE[size]) > 0:
self._cl = CLBuffer.BUFFER_CACHE[size].pop()
else:
# TODO: on GPU OOM, clear the cache
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size)
GlobalCounters.mem_used += self._cl.size
def __del__(self):
if CLCACHE: CLBuffer.BUFFER_CACHE[self._cl.size].append(self._cl)
else: GlobalCounters.mem_used -= self._cl.size
def copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
def copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
class CLImage(RawBuffer):
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
IMAGE : Final = True
def __init__(self, shape):
self._cl = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
class CLProgram:
def __init__(self, name:str, prg:str, binary=False, argdtypes=None):
self.name, self.argdtypes, self.clprogram = name, argdtypes, cl.Program(CL.cl_ctx, CL.cl_ctx.devices, [prg]) if binary else cl.Program(CL.cl_ctx, prg) # type: ignore
try:
self._clprg = self.clprogram.build()
except cl.RuntimeError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
self.clprg = self._clprg.__getattr__(name)
if DEBUG >= 5 and not OSX: print(self.clprogram.get_info(cl.program_info.BINARIES)[0].decode('utf-8')) # print the PTX for NVIDIA. TODO: probably broken for everything else
if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
e = self.clprg(CL.cl_queue, global_size, local_size, *[x._cl if isinstance(x, (CLBuffer, CLImage)) else x for x in bufs])
if wait:
CL.cl_queue.finish()
return ((e.profile.end - e.profile.start) * OSX_TIMING_RATIO) * 1e-9
return None
class CLCodegen(GPUCodegen):
lang = GPULanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)])
class GPUBuffer(CompiledBuffer):
raw_buffer_type = CLBuffer
# override this method for image
@classmethod
def create_raw_buffer(cls, shape, backing) -> RawBuffer:
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and not backing: return CLImage(shape)
else: return super().create_raw_buffer(shape, backing)
codegen_type = CLCodegen
runtime_type = CLProgram
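As with the other backends, CLProgram and CLBuffer can be driven directly for debugging; CLProgram unwraps CLBuffer/CLImage arguments to their underlying cl.Buffer before enqueueing. A hedged sketch, assuming pyopencl and at least one OpenCL device (the add4 kernel is illustrative):

import numpy as np
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram

src = "__kernel void add4(__global float *out, __global float *a, __global float *b) { int i = get_global_id(0); out[i] = a[i] + b[i]; }"
prg = CLProgram("add4", src)

a, b, out = CLBuffer(16), CLBuffer(16), CLBuffer(16)   # 16 bytes = 4 floats each
ha, hb = np.arange(4, dtype=np.float32), np.full(4, 10, dtype=np.float32)
a.copyin(ha); b.copyin(hb)                             # async copies on the in-order queue
prg([4], None, out, a, b)                              # global size 4, driver picks the local size
res = np.empty(4, dtype=np.float32)
out.copyout(res)                                       # blocking copy back
print(res)                                             # [10. 11. 12. 13.]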

View File

@@ -1,13 +1,12 @@
import time, hashlib, ctypes
from typing import ClassVar
from tinygrad.ops import CompiledBuffer
from tinygrad.runtime.ops_clang import RawMallocBuffer
from tinygrad.helpers import getenv, DEBUG
from tinygrad.ops import GlobalCounters
import hashlib
import time
import ctypes
from ctypes import CFUNCTYPE
from tinygrad.codegen.llvm import LLVMCodegen
import llvmlite.binding as llvm # type: ignore
from llvmlite import ir # type: ignore
class LLVM:
target_machine : ClassVar[llvm.targets.TargetMachine] = None
@@ -15,8 +14,7 @@ class LLVM:
optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
def __init__(self):
if LLVM.engine is not None:
return
if LLVM.engine is not None: return
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
@@ -46,42 +44,26 @@ class LLVM:
backing_mod.triple = llvm.get_process_triple()
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
# TODO: LLVMProgram
def exec(self, module:ir.Module, bufs, op_estimate=0, mem_estimate=0):
module.triple = llvm.get_process_triple()
module.data_layout = self.engine.target_data
llvm_ir = str(module)
if DEBUG >= 2:
print(llvm_ir)
mod = llvm.parse_assembly(llvm_ir)
mod.verify()
LLVM.optimizer.run(mod)
if DEBUG >= 4:
print("Optimized IR:")
print(str(mod))
mod.name = hashlib.sha1(llvm_ir.encode('utf-8')).hexdigest()
if DEBUG >= 3:
print(LLVM.target_machine.emit_assembly(mod))
LLVM.engine.add_module(mod)
class LLVMProgram:
def __init__(self, name:str, prg:str, binary=False):
self.mod = llvm.parse_assembly(prg)
self.mod.verify()
LLVM().optimizer.run(self.mod)
self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
LLVM.engine.add_module(self.mod)
LLVM.engine.finalize_object()
self.fxn = LLVM.engine.get_function_address(name)
# call function (NOTE: if the types don't match, there's likely something wrong with the cache)
#cfunc = CFUNCTYPE(ctypes.c_int, *[type(x._buf) for x in bufs])(LLVM.engine.get_function_address('exec'))
def __del__(self): LLVM.engine.remove_module(self.mod)
# why is this needed without the types. fixed tests below
# LLVM=1 OPT=2 python3 test/test_ops.py TestOps.test_cat TestOps.test_multicat
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for x in bufs])(LLVM.engine.get_function_address('exec'))
st = time.monotonic()
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.POINTER(ctypes.c_float) for _ in bufs])(self.fxn)
if wait: st = time.monotonic()
cfunc(*[x._buf for x in bufs])
et = time.monotonic() - st
if DEBUG >= 1:
print(f"**LLVM** time {et*1000:7.2f} ms OPs {op_estimate/1e6:7.2f}M -- {(op_estimate/1e9)/et:5.2f} GFLOPS -- {mem_estimate:10d} reads -- {(mem_estimate*4/1e9)/et:5.2f} GB/s")
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
if wait: return time.monotonic()-st
# we are done
LLVM.engine.remove_module(mod)
return cfunc
class LLVMBuffer(CompiledBuffer):
raw_buffer_type = RawMallocBuffer
codegen_type = LLVMCodegen
runtime_type = LLVMProgram

View File

@@ -0,0 +1,92 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-libdispatch
import os, subprocess, pathlib, functools
import Metal, Cocoa, libdispatch # type: ignore
import numpy as np
from typing import List, Any
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
from tinygrad.helpers import prod, getenv, DEBUG
from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
METAL_XCODE = getenv("METAL_XCODE")
class _METAL:
mtl_buffers_in_flight : List[Any] = []
@functools.cached_property
def device(self):
return Metal.MTLCreateSystemDefaultDevice()
@functools.cached_property
def mtl_queue(self):
return METAL.device.newCommandQueue()
METAL = _METAL()
class RawMetalBuffer(RawBufferCopyIn):
def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
def __del__(self): self._cl.release()
def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data)
def toCPU(self) -> np.ndarray:
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
METAL.mtl_buffers_in_flight = []
return self._as_np() # no copy!
class MetalProgram:
def __init__(self, name:str, prg:str):
if DEBUG >= 6: # dump llvm
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
print(dis.decode('utf-8'))
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library, err = METAL.device.newLibraryWithData_error_(data, None)
else:
options = Metal.MTLCompileOptions.alloc().init()
self.library, err = METAL.device.newLibraryWithSource_options_error_(prg, options, None)
assert err is None, str(err)
self.fxn = self.library.newFunctionWithName_(name)
# hacks to disassemble shader
if DEBUG >= 5:
arc, err = METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
assert err is None, str(err)
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
desc.setComputeFunction_(self.fxn)
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
assert err is None, str(err)
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
assert err is None, str(err)
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
self.pipeline_state, err = METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)
assert err is None, str(err)
def __call__(self, global_size, local_size, *bufs, wait=False):
global_size += [1] * (3-len(global_size))
if local_size is None: local_size = [32]
local_size += [1] * (3-len(local_size))
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._cl, 0, i)
encoder.dispatchThreads_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
if wait:
command_buffer.waitUntilCompleted()
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
class MetalCodegen(GPUCodegen):
lang = GPULanguage(
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
class MetalBuffer(CompiledBuffer):
raw_buffer_type = RawMetalBuffer
codegen_type = MetalCodegen
runtime_type = MetalProgram
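MetalProgram follows the same pattern with RawMetalBuffer as the argument type; toCPU waits for in-flight command buffers before reading back. A hedged sketch, assuming macOS with the pyobjc Metal bindings installed (the add4 kernel is illustrative, not part of tinygrad):

import numpy as np
from tinygrad.runtime.ops_metal import MetalProgram, RawMetalBuffer

src = """#include <metal_stdlib>
using namespace metal;
kernel void add4(device float *out, device const float *a, device const float *b,
                 uint3 gid [[thread_position_in_grid]]) {
  out[gid.x] = a[gid.x] + b[gid.x];
}"""
prg = MetalProgram("add4", src)

a, b, out = RawMetalBuffer(16), RawMetalBuffer(16), RawMetalBuffer(16)  # 16 bytes = 4 floats each
a.copyin(np.arange(4, dtype=np.float32))
b.copyin(np.full(4, 10, dtype=np.float32))
prg([4], [4], out, a, b)   # 4 threads total, one threadgroup of 4
print(out.toCPU())         # [10. 11. 12. 13.]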

View File

@@ -1,8 +1,8 @@
import torch
from typing import ClassVar, Final, Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, GenericExecAST, Op
from typing import ClassVar, Dict, Callable
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, InterpretedBuffer, Op
from tinygrad.helpers import getenv
from tinygrad.llops.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
@@ -12,10 +12,9 @@ torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
}}
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(GenericExecAST):
class TorchBuffer(InterpretedBuffer):
fxn_for_op : ClassVar = torch_fxn_for_op
SUPPORTS_SIMPLE_PADDING : Final = True
@staticmethod
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
def toCPU(x): return x.buf.cpu().numpy()
def toCPU(x): return x._buf.cpu().numpy()