mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add renderer class (#4524)
* add renderer class
* tests pass
* fix pylint
* fix tensor cores
This commit is contained in:
@@ -55,18 +55,18 @@ if __name__ == "__main__":
|
||||
lins:List[Linearizer] = []
|
||||
|
||||
# always try hand coded opt
|
||||
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
|
||||
lin = Linearizer(*si.ast, opts=device.renderer)
|
||||
lin.hand_coded_optimizations()
|
||||
lins.append(lin)
|
||||
|
||||
# maybe try tensor cores
|
||||
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
|
||||
lin = Linearizer(*si.ast, opts=device.renderer)
|
||||
if lin.apply_tensor_cores():
|
||||
lins.append(lin)
|
||||
|
||||
# try a beam search
|
||||
if beam:=getenv("BEAM"):
|
||||
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
|
||||
lin = Linearizer(*si.ast, opts=device.renderer)
|
||||
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
lins.append(lin)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||
for ast in ast_dedup:
|
||||
k = Device["CLANG"].get_linearizer(*ast)
|
||||
k.linearize()
|
||||
src = Device["CLANG"].compiler.compiler_opts.renderer(to_function_name(k.name), k.uops)
|
||||
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
|
||||
srcs[ast] = (k.name, src)
|
||||
print("functions:", len(srcs))
|
||||
used_buffers = dedup(flatten([si.bufs for si in sched]))
|
||||
|
||||
@@ -25,7 +25,7 @@ if __name__ == '__main__':
|
||||
|
||||
for i, ast_str in enumerate(ast_strs):
|
||||
print(f"optimizing {i}/{len(ast_strs)}\nast={ast_str}")
|
||||
lin = ast_str_to_lin(ast_str, opts=device.compiler.compiler_opts)
|
||||
lin = ast_str_to_lin(ast_str, opts=device.renderer)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
|
||||
|
||||
2
test/external/external_test_hip_compile.py
vendored
2
test/external/external_test_hip_compile.py
vendored
@@ -32,7 +32,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
|
||||
compile_hip(code)
|
||||
return (time.perf_counter() - st) * 1000
|
||||
|
||||
tinygrad_tm = min([time_compile(Device[Device.DEFAULT].compiler.compiler_opts.renderer(f"test{i}", lin.uops)) for i in range(10)])
|
||||
tinygrad_tm = min([time_compile(Device[Device.DEFAULT].renderer.render(f"test{i}", lin.uops)) for i in range(10)])
|
||||
ref_tm = min([time_compile(reference.format(name=f"test{i}")) for i in range(10)])
|
||||
print(f"tinygrad {tinygrad_tm:6.2f} ms")
|
||||
print(f"reference {ref_tm:6.2f} ms")
|
||||
|
||||
2
test/external/speed_beam_v_hcopt.py
vendored
2
test/external/speed_beam_v_hcopt.py
vendored
@@ -15,7 +15,7 @@ if __name__ == "__main__":
|
||||
beam_won, tested = 0, 0
|
||||
|
||||
for num, ast in enumerate(ast_strs[:test_n]):
|
||||
def new_lin(): return ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
|
||||
def new_lin(): return ast_str_to_lin(ast, opts=dev.renderer)
|
||||
|
||||
k = new_lin()
|
||||
# k.required_optimizations()
|
||||
|
||||
4
test/external/speed_compare_cuda_nv.py
vendored
4
test/external/speed_compare_cuda_nv.py
vendored
@@ -23,12 +23,12 @@ if __name__ == "__main__":
|
||||
average_tm_cuda, average_tm_nv = 0, 0
|
||||
for num,ast in enumerate(ast_strs):
|
||||
# cuda compile
|
||||
culin = ast_str_to_lin(ast, opts=cudev.compiler.compiler_opts)
|
||||
culin = ast_str_to_lin(ast, opts=cudev.renderer)
|
||||
culin.hand_coded_optimizations()
|
||||
cuda_prg = cudev.to_runner(culin)
|
||||
cubufs = bufs_from_lin(culin)
|
||||
|
||||
nvlin = ast_str_to_lin(ast, opts=nvdev.compiler.compiler_opts)
|
||||
nvlin = ast_str_to_lin(ast, opts=nvdev.renderer)
|
||||
nvlin.hand_coded_optimizations()
|
||||
nv_prg = nvdev.to_runner(nvlin)
|
||||
nvbufs = bufs_from_lin(nvlin)
|
||||
|
||||
2
test/external/speed_compare_cuda_ptx.py
vendored
2
test/external/speed_compare_cuda_ptx.py
vendored
@@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||
average_tm_cuda, average_tm_ptx = 0, 0
|
||||
for num,ast in enumerate(ast_strs):
|
||||
# cuda compile
|
||||
lin = ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
|
||||
lin = ast_str_to_lin(ast, opts=dev.renderer)
|
||||
lin.hand_coded_optimizations()
|
||||
cuda_prg = dev.to_runner(lin)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ class TestDeviceSpeed(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.dev = Device[Device.DEFAULT]
|
||||
cls.empty = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", UOpGraph())
|
||||
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph())
|
||||
|
||||
def test_empty_compile(self):
|
||||
with Timing("compiler "):
|
||||
|
||||
@@ -141,7 +141,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert num_ops <= 1, "more alu uops than needed"
|
||||
|
||||
def test_reduce_upcast(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
|
||||
if not Device[Device.DEFAULT].renderer.supports_float4:
|
||||
self.skipTest("device does not support upcast")
|
||||
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
|
||||
r = Tensor.conv2d(x,w,padding=1).relu()
|
||||
@@ -157,7 +157,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4)
|
||||
|
||||
def test_upcast_with_locals(self):
|
||||
if not (opts:=Device[Device.DEFAULT].compiler.compiler_opts).has_local or not opts.has_shared or not opts.supports_float4:
|
||||
if not (opts:=Device[Device.DEFAULT].renderer).has_local or not opts.has_shared or not opts.supports_float4:
|
||||
self.skipTest("device does not support upcasted reduce with locals")
|
||||
|
||||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
@@ -218,16 +218,16 @@ class TestLinearizer(unittest.TestCase):
|
||||
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
|
||||
|
||||
def test_tensor_cores(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
|
||||
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
|
||||
|
||||
def test_tensor_cores_padded(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
|
||||
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
pad = 1
|
||||
|
||||
@@ -251,9 +251,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "RHIP", "RHIP is really slow here")
|
||||
def test_tensor_cores_multi_reduce(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
|
||||
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
|
||||
golden_result = None
|
||||
@@ -358,7 +358,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
helper(Tensor.arange(256), max_ops=2)
|
||||
helper(Tensor.arange(255), max_ops=0)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
|
||||
class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(k):
|
||||
@@ -567,7 +567,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49
|
||||
|
||||
def test_matvec(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local:
|
||||
if not Device[Device.DEFAULT].renderer.has_local:
|
||||
self.skipTest("Only devices with locals")
|
||||
N = 128
|
||||
a = Tensor.rand(1, N).realize()
|
||||
@@ -618,7 +618,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
|
||||
|
||||
class TestKernelOpts(unittest.TestCase):
|
||||
def test_local_and_grouped_reduce(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
|
||||
N = 128
|
||||
@@ -664,7 +664,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
])
|
||||
|
||||
def test_matmul(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
|
||||
N = 128
|
||||
@@ -694,7 +694,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
])
|
||||
|
||||
def test_double_reduce(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
|
||||
N = 128
|
||||
@@ -721,7 +721,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
])
|
||||
|
||||
def test_invalid_tensor_core_extra_opts(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
if Device.DEFAULT not in tensor_cores:
|
||||
self.skipTest("No tensor cores for device")
|
||||
@@ -742,25 +742,25 @@ class TestKernelOpts(unittest.TestCase):
|
||||
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
|
||||
|
||||
def test_buf_index_not_found_tensor_core(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
if Device.DEFAULT not in tensor_cores:
|
||||
self.skipTest("No tensor cores for device")
|
||||
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
k = Linearizer(ast, opts=Device[Device.DEFAULT].compiler.compiler_opts)
|
||||
k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
with self.assertRaises(KernelOptError):
|
||||
k.apply_opt(Opt(OptOps.TC, 0, 1))
|
||||
|
||||
def test_tensor_core_opts(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
if Device.DEFAULT not in tensor_cores:
|
||||
self.skipTest("No tensor cores for device")
|
||||
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
|
||||
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
|
||||
if tc.dtype_in == dtypes.bfloat16: continue
|
||||
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
|
||||
@@ -889,7 +889,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
])
|
||||
|
||||
def test_color_shapes_with_local(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
|
||||
N = 32
|
||||
@@ -959,7 +959,7 @@ class TestLinearizerHelper(unittest.TestCase):
|
||||
assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
|
||||
|
||||
class TestLinearizerUOptimize(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "device doesn't support float4")
|
||||
def test_grouped_store_phis(self):
|
||||
x, y = Tensor.randn(64,64), Tensor.randn(64,64)
|
||||
out = x.matmul(y)
|
||||
@@ -973,7 +973,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
for val in store_vals:
|
||||
assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "device doesn't support float4")
|
||||
def test_grouped_store_values(self):
|
||||
x = Tensor.randn((4,3,6,6)).realize()
|
||||
out = x.flip((0,1)).contiguous()
|
||||
@@ -986,8 +986,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
|
||||
|
||||
def test_grouped_store_locals_and_globals(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
|
||||
not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared or \
|
||||
not Device[Device.DEFAULT].renderer.supports_float4:
|
||||
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
|
||||
|
||||
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
|
||||
@@ -1011,8 +1011,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
|
||||
|
||||
def test_grouped_store_local_only(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
|
||||
not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared or \
|
||||
not Device[Device.DEFAULT].renderer.supports_float4:
|
||||
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
|
||||
|
||||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
@@ -1031,7 +1031,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
assert stores[1].vin[-1].dtype == dtypes.float
|
||||
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.supports_float4:
|
||||
self.skipTest("Needs locals and float4")
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
opts = [
|
||||
@@ -1047,7 +1047,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
|
||||
|
||||
def test_skip_unmatching_upcasts_with_gep(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
|
||||
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.supports_float4:
|
||||
self.skipTest("Needs locals and float4")
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
|
||||
|
||||
@@ -12,8 +12,8 @@ from tinygrad.codegen.uops import exec_alu, UOpGraph
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
src = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", uops)
|
||||
has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
import math, itertools
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
|
||||
from tinygrad.device import Device, CompilerOptions
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes, ImageDType, DType
|
||||
from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -70,9 +71,8 @@ class LocalBuffer(NamedTuple):
|
||||
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, *ast:LazyOp, opts:Optional[CompilerOptions]=None):
|
||||
self.opts = opts if opts is not None else (device.compiler.compiler_opts if (device:=Device[Device.DEFAULT]).compiler is not None else
|
||||
CompilerOptions(Device.DEFAULT))
|
||||
def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
|
||||
assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
|
||||
self.ast = ast
|
||||
|
||||
@@ -5,11 +5,12 @@ from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
from tinygrad.features.image import to_image_idx
|
||||
from tinygrad.renderer import Program
|
||||
|
||||
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
|
||||
|
||||
@@ -441,3 +442,12 @@ class Linearizer(Kernel):
|
||||
ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)]
|
||||
cache[x] = ret
|
||||
return ret
|
||||
|
||||
def to_program(self) -> Program:
  """Linearize this kernel, render its uops through `self.opts`, and package the result as a Program."""
  self.linearize()
  info = get_lazyop_info(self.ast[0])
  ops, mem = self.uops.flops_mem()
  # product of every global and local launch dimension; a missing/empty size list contributes nothing
  run_count = prod((self.global_size or []) + (self.local_size or []))
  return Program(
    name=self.name,
    src=self.opts.render(to_function_name(self.name), self.uops),
    dname=self.opts.device,
    global_size=self.global_size,
    local_size=self.local_size,
    uops=self.uops,
    # NOTE: we use min here to ignore the indexing FLOPS
    op_estimate=min(info.flops, ops * run_count),
    mem_estimate=min(info.mem_estimate, mem * run_count))
|
||||
|
||||
@@ -2,13 +2,13 @@ from __future__ import annotations
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable, Any
|
||||
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, Any
|
||||
import importlib, inspect, functools, pathlib, os, ctypes
|
||||
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
from tinygrad.helpers import getenv, all_int, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.ops import LazyOp, get_lazyop_info
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.ops import LazyOp
|
||||
from tinygrad.renderer import Renderer, Program
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
@@ -181,63 +181,7 @@ class Runner:
|
||||
|
||||
# **************** for Compiled Devices ****************
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Program:
|
||||
name:str
|
||||
src:str
|
||||
dname:str
|
||||
global_size:Optional[List[int]]=None
|
||||
local_size:Optional[List[int]]=None
|
||||
uops:Optional[UOpGraph]=None
|
||||
op_estimate:sint=0
|
||||
mem_estimate:sint=0
|
||||
|
||||
@functools.cached_property
|
||||
def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
|
||||
|
||||
@functools.cached_property
|
||||
def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
|
||||
|
||||
@functools.cached_property
|
||||
def outcount(self) -> int: return sum(x[1] for x in self.globals)
|
||||
|
||||
@functools.cached_property
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
def launch_dims(self, var_vals:Dict[Variable, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
return global_size, local_size
|
||||
|
||||
def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompilerOptions:
|
||||
device: str = ""
|
||||
suffix: str = ""
|
||||
# TODO: make this generic with a list of supported types
|
||||
supports_float4: bool = True
|
||||
has_local: bool = True
|
||||
has_shared: bool = True
|
||||
has_tensor_cores: bool = False
|
||||
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
|
||||
global_max: Optional[List[int]] = None
|
||||
local_max: Optional[List[int]] = None
|
||||
shared_max: int = 32768
|
||||
renderer: Callable = fake_renderer
|
||||
|
||||
def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
|
||||
k.linearize()
|
||||
info = get_lazyop_info(k.ast[0])
|
||||
ops, mem = k.uops.flops_mem()
|
||||
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
|
||||
# NOTE: we use min here to ignore the indexing FLOPS
|
||||
return Program(k.name, self.renderer(to_function_name(k.name), k.uops),
|
||||
override_device if override_device else self.device,
|
||||
k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
|
||||
|
||||
class Compiler:
|
||||
compiler_opts: ClassVar[CompilerOptions]
|
||||
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
|
||||
def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
|
||||
def compile_cached(self, src:str) -> bytes:
|
||||
@@ -276,24 +220,25 @@ class CompiledRunner(Runner):
|
||||
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
|
||||
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
|
||||
class Compiled:
|
||||
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
|
||||
def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
|
||||
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
|
||||
self.renderer = renderer if renderer else Renderer()
|
||||
def synchronize(self): pass # override this in your device
|
||||
|
||||
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.compiler_opts.to_program(k, override_device=self.dname))
|
||||
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(replace(k.to_program(), dname=self.dname))
|
||||
|
||||
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
|
||||
if DEBUG >= 3:
|
||||
from tinygrad.features.graph import print_tree
|
||||
for op in ast: print_tree(op)
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(*ast, opts=self.compiler.compiler_opts)
|
||||
k = Linearizer(*ast, opts=self.renderer)
|
||||
k.required_optimizations()
|
||||
if not NOOPT:
|
||||
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
|
||||
if BEAM >= 1:
|
||||
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
|
||||
kb, k_opt = Linearizer(*ast, opts=self.compiler.compiler_opts), k
|
||||
kb, k_opt = Linearizer(*ast, opts=self.renderer), k
|
||||
kb.required_optimizations()
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
@@ -301,7 +246,7 @@ class Compiled:
|
||||
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
|
||||
lins: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
|
||||
if used_tensor_cores:
|
||||
lins.append(("hc", Linearizer(*ast, opts=self.compiler.compiler_opts)))
|
||||
lins.append(("hc", Linearizer(*ast, opts=self.renderer)))
|
||||
lins[-1][1].hand_coded_optimizations()
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
|
||||
@@ -51,7 +51,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
|
||||
try:
|
||||
x[1].linearize()
|
||||
if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
|
||||
p = compiler.compiler_opts.to_program(x[1])
|
||||
p = x[1].to_program()
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(p.src)
|
||||
et = time.perf_counter() - st
|
||||
@@ -174,7 +174,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
|
||||
p = dev.compiler.compiler_opts.to_program(lin)
|
||||
p = lin.to_program()
|
||||
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
||||
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
||||
|
||||
|
||||
49
tinygrad/renderer/__init__.py
Normal file
49
tinygrad/renderer/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Optional, List, Tuple, Dict
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import to_function_name
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.shape.symbolic import sym_infer, sint, Variable
|
||||
|
||||
@dataclass(frozen=True)
class Program:
  """Immutable record of a named kernel: its source string, device name, launch sizes, optional uop graph and cost estimates."""
  name:str
  src:str
  dname:str
  global_size:Optional[List[int]]=None
  local_size:Optional[List[int]]=None
  uops:Optional[UOpGraph]=None
  op_estimate:sint=0
  mem_estimate:sint=0

  @functools.cached_property
  def vars(self) -> List[Variable]:
    # with no uop graph attached there is nothing to query
    if self.uops is None: return []
    return self.uops.vars()

  @functools.cached_property
  def globals(self) -> List[Tuple[int, bool]]:
    if self.uops is None: return []
    return self.uops.globals()

  @functools.cached_property
  def outcount(self) -> int:
    # adds up the boolean (second) element of every `globals` entry
    return sum(entry[1] for entry in self.globals)

  @functools.cached_property
  def function_name(self) -> str:
    return to_function_name(self.name)

  def launch_dims(self, var_vals:Dict[Variable, int]):
    """Return `(global_size, local_size)` with each entry passed through `sym_infer` using `var_vals`; a size that is None stays None."""
    def resolve(sizes):
      return None if sizes is None else [sym_infer(sz, var_vals) for sz in sizes]
    return resolve(self.global_size), resolve(self.local_size)
|
||||
|
||||
class Renderer:
  """Base renderer: class-level capability defaults plus a `render` hook that turns a UOpGraph into a source string.

  The `render` defined here always raises NotImplementedError.
  """
  device: str = ""
  suffix: str = ""
  # TODO: make this generic with a list of supported types
  supports_float4: bool = True
  has_local: bool = True
  has_shared: bool = True
  has_tensor_cores: bool = False
  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
  global_max: Optional[List[int]] = None
  local_max: Optional[List[int]] = None
  shared_max: int = 32768

  def render(self, name:str, uops:UOpGraph) -> str:
    raise NotImplementedError("needs a renderer")
|
||||
@@ -1,15 +1,16 @@
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
|
||||
import functools, struct, copy
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct, copy
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def render_val(x, dtype):
  """Render the constant `x` as a literal string for `dtype`.

  Float dtypes become a hex literal of the value's packed bytes, most significant
  byte first: `0d` + 16 hex digits for double, `0x` + 4 hex digits for half, and
  `0f` + 8 hex digits for every other float dtype. Non-float dtypes become the
  decimal `int(x)`, with a trailing `U` when the dtype is unsigned.

  Raises:
    struct.error: if `x` cannot be packed into the requested float format.
  """
  if dtypes.is_float(dtype):
    # Pack big-endian explicitly instead of packing in native order and reversing:
    # the output is identical on little-endian hosts and no longer depends on the host.
    if dtype == dtypes.double: return "0d" + struct.pack(">d", x).hex().upper()
    if dtype == dtypes.half: return "0x" + struct.pack(">e", x).hex().upper()
    return "0f" + struct.pack(">f", x).hex().upper()
  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
@@ -33,191 +34,16 @@ def ptr_ar(root, uops):
|
||||
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
class AssemblyLanguage(NamedTuple):
|
||||
kernel_prefix: str = ""
|
||||
barrier: str = ""
|
||||
load_global: bool = False
|
||||
label_prefix: str = ""
|
||||
gid: List[str] = []
|
||||
gdim: List[str] = []
|
||||
lid: List[str] = []
|
||||
const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
|
||||
asm_for_op: Dict[Op, Callable[...,str]] = {}
|
||||
types: Dict[DType, str] = INVERSE_DTYPES_DICT
|
||||
supports_half: List[Op] = []
|
||||
class PTXRenderer(Renderer):
|
||||
device = "CUDA"
|
||||
suffix = "PTX"
|
||||
global_max=[65535, 65535, 2147483647]
|
||||
local_max=[64, 1024, 1024]
|
||||
shared_max=49152
|
||||
has_tensor_cores = False
|
||||
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
|
||||
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
|
||||
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
|
||||
|
||||
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
|
||||
def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
|
||||
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
|
||||
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
|
||||
|
||||
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
|
||||
def mem_type(self, dtype) -> str: raise NotImplementedError()
|
||||
|
||||
def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str:
|
||||
# editing the uops breaks beam search
|
||||
uops = copy.deepcopy(_uops)
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
|
||||
"vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
|
||||
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
|
||||
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
|
||||
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
|
||||
])
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
# all uops must be a register
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
uops.optimize_loops()
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, Union[List[str], str]] = {}
|
||||
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
|
||||
nonlocal c, r
|
||||
prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
|
||||
c[prefix] += 1
|
||||
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
return f"%{prefix}{c[prefix]-1}"
|
||||
|
||||
c_label: DefaultDict[str, int] = defaultdict(int)
|
||||
r_label: Dict[UOp, str] = {}
|
||||
def ssa_label(prefix:str, u:UOp):
|
||||
nonlocal c_label, r_label
|
||||
c_label[prefix] += 1
|
||||
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
return r_label[u]
|
||||
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in lang.const_requires_mov:
|
||||
kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
|
||||
return out
|
||||
return lang.render_const(x, dtype)
|
||||
|
||||
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
|
||||
if atype == dtype:
|
||||
if u: r[u] = a
|
||||
return a
|
||||
kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
|
||||
return ret
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
if vin[2].dtype.count > 1:
|
||||
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
|
||||
f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
|
||||
else:
|
||||
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
# pass in the other dtype here
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
|
||||
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
|
||||
r[u] = "%" + args[1]
|
||||
kernel = [f".reg .u32 %{args[1]};"] + kernel
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
else:
|
||||
kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", dtype))
|
||||
r[u] = f"%{nm}"
|
||||
if lang.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = r[vin[2]]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return lang.render_kernel(kernel, function_name, bufs, c.items())
|
||||
|
||||
class PTXLanguage(AssemblyLanguage):
|
||||
# language options
|
||||
kernel_prefix = """.version VERSION
|
||||
.target TARGET
|
||||
.address_size 64
|
||||
@@ -229,7 +55,7 @@ class PTXLanguage(AssemblyLanguage):
|
||||
gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
|
||||
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
|
||||
asm_for_op = {
|
||||
asm_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
@@ -245,13 +71,14 @@ class PTXLanguage(AssemblyLanguage):
|
||||
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
|
||||
TernaryOps.WHERE]
|
||||
# HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
|
||||
types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
|
||||
types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
|
||||
|
||||
const_requires_mov = [dtypes.half, dtypes.bool]
|
||||
const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
|
||||
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
|
||||
val = render_val(x, dtype)
|
||||
@@ -270,7 +97,7 @@ class PTXLanguage(AssemblyLanguage):
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
|
||||
assert dtype is not dtypes.bool
|
||||
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
|
||||
else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
|
||||
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
|
||||
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
|
||||
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
|
||||
@@ -291,4 +118,160 @@ class PTXLanguage(AssemblyLanguage):
|
||||
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
|
||||
"\n}")
|
||||
|
||||
PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
|
||||
def render(self, name:str, _uops:UOpGraph) -> str:
  """Render the uop graph `_uops` into assembly text for a kernel called `name`.

  Steps, in order:
    1. deep-copy the graph so the caller's graph is left untouched,
    2. rewrite it with a PatternMatcher, run `ptr_ar` on every LOAD/STORE, then
       `remove_childless` and `optimize_loops`,
    3. walk the uops once, appending instruction strings to `kernel` and recording
       the register name(s) of each uop in `r`,
    4. hand everything to `self.render_kernel`.

  Args:
    name: kernel name forwarded to `render_kernel`.
    _uops: graph to render; only a copy is modified.

  Returns:
    Whatever `self.render_kernel(kernel, name, bufs, c.items())` returns.

  Raises:
    NotImplementedError: for a uop that no branch below handles.
    AssertionError: when a required dtype is None or a local buffer exceeds 0xC000 bytes.
  """
  # editing the uops breaks beam search
  uops = copy.deepcopy(_uops)
  kernel:List[str] = []
  bufs = []

  # each entry is (pattern, replacement builder); "__name__" keys become the lambda's arguments
  matcher = PatternMatcher([
    # CMPEQ whose first input is bool -> NEG(XOR(same inputs))
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
     lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
    # CMPLT whose first input x is bool -> MUL(NEG(x), y)
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
     lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
    # float ADD with a MUL input -> MULACC(mul inputs..., other input)
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
      "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
     lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
    # half ALU ops not listed in supports_half: cast inputs to float32, compute, cast back to half
    *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
       lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
      for op in self.asm_for_op.keys() if op not in self.supports_half],
    # four-input bool LOAD -> CAST to bool of an int8 LOAD whose last input is cast to uint8
    # NOTE(review): root.arg is passed to the outer CAST, not to the inner LOAD — confirm this is intended
    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
     lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
    # two-input bool LOAD -> CAST to bool of a uint8 LOAD
    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
     lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
    # four-input STORE of a bool value z -> STORE of z cast to uint8
    # NOTE(review): the replacement has only three inputs; the fourth matched input is not forwarded — confirm
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
    # three-input STORE of a bool value z -> STORE of z cast to uint8
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
    # four-input STORE -> same STORE with its fourth input g cast to bool
    # NOTE(review): root.arg is passed to the CAST here, so the new STORE is built without an arg — confirm
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
     lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
  ])

  # here we do a pretransform on UOps to fix some shortcomings of PTX
  # all uops must be a register
  matcher.rewrite_graph(uops)

  # snapshot the LOAD/STORE list first: ptr_ar inserts new uops into the graph while we iterate
  for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
  uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
  uops.optimize_loops()

  # append one kernel entry; several strings are joined with newlines into a single entry
  def kk(*s: str): kernel.append("\n".join(s))

  # c counts registers handed out per "<prefix>_<type>_" key; r maps a uop to its register name(s)
  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, Union[List[str], str]] = {}
  def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
    # allocate the next "%<prefix>_<type>_<n>" name; the type string comes from `dtype` or, if None, from u's dtype
    nonlocal c, r
    prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
    c[prefix] += 1
    if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  # same scheme for branch labels, kept in separate counters and a separate map
  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(prefix:str, u:UOp):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]

  def const(x:ConstType, dtype:DType, mov=False):
    # with mov (or for dtypes in const_requires_mov) emit code into a fresh register and return its name;
    # otherwise return whatever render_const produces directly
    if mov or dtype in self.const_requires_mov:
      kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
      return out
    return self.render_const(x, dtype)

  def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    # same source and destination dtype: no code emitted, u (if given) just aliases a
    # NOTE(review): the `pred` parameter is accepted but never read in this body
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
    return ret

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop is UOps.IF:
      assert vin[0].dtype is not None
      # branch built by render_bra from the new label, the condition cast to bool and "<label>_true", followed by that "_true" label
      kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
    elif uop is UOps.ENDLOOP:
      # add "1" to the loop register in place, compare it against the LOOP's second input, then branch
      kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
          self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
      kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
    elif uop is UOps.ENDIF:
      # place the label that was allocated for the matching IF
      kk(f"{r_label[vin[0]]}:")
    elif uop is UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      if vin[2].dtype.count > 1:
        # vector store written inline; a fourth input, when present, becomes the "@pred" guard
        kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
          f"st{u.arg}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
      else:
        kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
    else:
      # every uop handled below must carry a dtype
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
      elif uop is UOps.ALU:
        assert vin[0].dtype is not None
        if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
          # pass in the other dtype here
          kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
        else:
          kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
      elif uop is UOps.DEFINE_ACC:
        if dtype.count > 1:
          # vector accumulator: one scalar register per element, each initialised with a mov
          r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
        else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
      elif uop is UOps.SPECIAL:
        assert args[1][0] != "i", "idx not supported"
        # names starting with 'g' read from self.gid, all others from self.lid
        kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
        r[u] = "%" + args[1]
        # the register declaration is put at the front of the kernel, ahead of everything emitted so far
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop is UOps.CONST:
        if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
        else: r[u] = const(args, dtype, mov=True)
      # GEP emits no code: it indexes into the source's list of register names
      elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
      elif uop is UOps.LOAD:
        assert vin[1].dtype is not None
        if dtype.count > 1:
          r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          # more than three inputs: zero every destination register, then guard the load with "@pred"
          if(len(vin)>3):
            for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
          kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
            + f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
        else:
          kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
                              alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
      elif uop is UOps.PHI:
        # move the new value into the first input's register; the PHI then aliases that register
        kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop in {UOps.CAST, UOps.BITCAST}:
        assert vin[0].dtype is not None
        # a vector result emits no code: it is the list of its inputs' registers
        if dtype.count>1: r[u] = [r[x] for x in vin]  # type: ignore
        else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
      elif uop is UOps.DEFINE_LOCAL:
        # TODO: we should sum these, and fetch 0xC000 from somewhere
        assert args[1]*dtype.itemsize <= 0xC000, "too large local"
        kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop is UOps.DEFINE_VAR:
        bufs.append((args.expr, dtype))
        r[u] = f"%{args.expr}"
        # with load_global set, the parameter is loaded into a register (ssa rebinds r[u] to it)
        if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
      elif uop is UOps.DEFINE_GLOBAL:
        bufs.append((nm:=f"data{args[0]}", dtype))
        r[u] = f"%{nm}"
        if self.load_global:
          # pointer-typed parameters are loaded as ulong
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
      elif uop is UOps.WMMA:
        wmma = []
        # for each of the first two inputs, move every pair of its registers into one b32 register
        for vv in vin[:2]:
          for i in range(0, len(r[vv]), 2):
            wmma.append(ssa("wmma", dtype="b32"))
            kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
        # the result reuses the third input's registers, which appear as both destination and last operand
        r[u] = r[vin[2]]
        # NOTE(review): the string continues across the backslash, so the next line's leading whitespace is part
        # of the emitted text; this view lost the original indentation, so the exact amount is a reconstruction
        kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
          {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
      else: raise NotImplementedError(f"no code for {uop}")

  return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
|
||||
import os, math
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import strip_parens, getenv, prod
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
class CStyleLanguage(Renderer):
|
||||
kernel_prefix: str = ""
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
@@ -17,8 +18,6 @@ class CStyleLanguage(NamedTuple):
|
||||
arg_int_prefix: str = "const int"
|
||||
barrier: str = ""
|
||||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
global_max: List[int] = []
|
||||
local_max: List[int] = []
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
uses_vload: bool = False
|
||||
@@ -88,100 +87,107 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
|
||||
def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
|
||||
def ssa(prefix:str, u:Optional[UOp]=None):
|
||||
nonlocal c, r
|
||||
ret = f"{prefix}{c[prefix]}"
|
||||
if u is not None: r[u] = ret
|
||||
c[prefix] += 1
|
||||
return ret
|
||||
def ssa(prefix:str, u:Optional[UOp]=None):
|
||||
nonlocal c, r
|
||||
ret = f"{prefix}{c[prefix]}"
|
||||
if u is not None: r[u] = ret
|
||||
c[prefix] += 1
|
||||
return ret
|
||||
|
||||
child_count = Counter(v for ru in uops for v in ru.vin)
|
||||
child_count = Counter(v for ru in uops for v in ru.vin)
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
# these four uops don't have output dtypes
|
||||
if uop is UOps.IF:
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.BARRIER: kk(lang.barrier)
|
||||
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
rendered_store = lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
# these four uops don't have output dtypes
|
||||
if uop is UOps.IF:
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
|
||||
else: operands = [r[v] for v in vin]
|
||||
val = lang.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.LOAD:
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa('precast')
|
||||
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
val = lang.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
elif uop is UOps.BARRIER: kk(self.barrier)
|
||||
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
|
||||
else: operands = [r[v] for v in vin]
|
||||
val = self.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa('precast')
|
||||
kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
val = self.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(self.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, uops)
|
||||
return self.render_kernel(name, kernel, bufs, uops)
|
||||
|
||||
class ClangLanguage(CStyleLanguage):
|
||||
class ClangRenderer(CStyleLanguage):
|
||||
device = "CLANG"
|
||||
supports_float4 = False
|
||||
has_local = False
|
||||
|
||||
# language options
|
||||
buffer_suffix = " restrict"
|
||||
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
|
||||
def ClangRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(ClangLanguage(), name, uops)
|
||||
|
||||
class OpenCLLanguage(CStyleLanguage):
|
||||
class OpenCLRenderer(CStyleLanguage):
|
||||
device = "GPU"
|
||||
|
||||
# language options
|
||||
kernel_prefix = "__kernel "
|
||||
buffer_prefix = "__global "
|
||||
smem_align = "__attribute__ ((aligned (16))) "
|
||||
@@ -197,9 +203,13 @@ class OpenCLLanguage(CStyleLanguage):
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
def OpenCLRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(OpenCLLanguage(), name, uops)
|
||||
|
||||
class MetalLanguage(CStyleLanguage):
|
||||
class MetalRenderer(CStyleLanguage):
|
||||
device = "METAL"
|
||||
has_tensor_cores=os.uname().machine == "arm64"
|
||||
shared_max=32768
|
||||
|
||||
# language options
|
||||
kernel_prefix = "kernel "
|
||||
buffer_prefix = "device "
|
||||
smem_prefix = "threadgroup "
|
||||
@@ -227,7 +237,6 @@ class MetalLanguage(CStyleLanguage):
|
||||
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
|
||||
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
def MetalRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(MetalLanguage(), name, uops)
|
||||
|
||||
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
|
||||
@@ -240,7 +249,15 @@ def _make_cuda_dtype(base_type, name, cnt):
|
||||
vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
|
||||
return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
|
||||
|
||||
class CUDALanguage(CStyleLanguage):
|
||||
class CUDARenderer(CStyleLanguage):
|
||||
device = "CUDA"
|
||||
global_max=[65535, 65535, 2147483647]
|
||||
local_max=[64, 1024, 1024]
|
||||
shared_max=49152
|
||||
has_tensor_cores = False
|
||||
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
|
||||
|
||||
# language options
|
||||
kernel_prefix = "extern \"C\" __global__ "
|
||||
smem_prefix = "__shared__ "
|
||||
smem_prefix_for_cast = False
|
||||
@@ -271,7 +288,6 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
|
||||
return c;}}""")
|
||||
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
def CUDARenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(CUDALanguage(), name, uops)
|
||||
|
||||
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
@@ -295,7 +311,12 @@ def _make_hip_dtype(base_type, name, cnt):
|
||||
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
|
||||
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
|
||||
|
||||
class HIPLanguage(CStyleLanguage):
|
||||
class HIPRenderer(CStyleLanguage):
|
||||
device = "HSA"
|
||||
has_tensor_cores = True
|
||||
shared_max = 65536
|
||||
|
||||
# language options
|
||||
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
|
||||
@@ -357,5 +378,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
|
||||
|
||||
def HIPRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(HIPLanguage(), name, uops)
|
||||
|
||||
@@ -4,6 +4,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
|
||||
@@ -65,88 +66,95 @@ def cast(bb, val, input_type, output_type, bitcast=False):
|
||||
|
||||
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
|
||||
|
||||
def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
class LLVMRenderer(Renderer):
|
||||
device = "LLVM"
|
||||
supports_float4=False
|
||||
has_local=False
|
||||
has_shared=False
|
||||
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
|
||||
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
func.attributes.add('"no-nans-fp-math"="true"')
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
# TODO: newvar probably shouldn't be optional
|
||||
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
||||
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
||||
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
||||
func.attributes.add('"no-nans-fp-math"="true"')
|
||||
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
# TODO: newvar probably shouldn't be optional
|
||||
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(lvars[vin[3]]): store_op()
|
||||
else: store_op()
|
||||
elif uop is UOps.ENDLOOP:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
|
||||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2].block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(vin) > 2:
|
||||
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
||||
if len(vin) > 3:
|
||||
with bb[-1].if_then(lvars[vin[3]]):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
lvars[u] = val
|
||||
elif uop is UOps.PHI:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
while backward.uop is UOps.PHI: backward = backward.vin[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
elif uop is UOps.ENDLOOP:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module)
|
||||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2].block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(vin) > 2:
|
||||
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
||||
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
lvars[u] = val
|
||||
elif uop is UOps.PHI:
|
||||
lvars[u] = lvars[vin[1]]
|
||||
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
||||
backward = vin[0]
|
||||
while backward.uop is UOps.PHI: backward = backward.vin[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module)
|
||||
|
||||
@@ -6,8 +6,8 @@ from tinygrad.device import Buffer, Device, CompiledRunner
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_clang import ClangProgram
|
||||
from tinygrad.renderer.cstyle import ClangLanguage
|
||||
render_dtype = ClangLanguage().render_dtype
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
render_dtype = ClangRenderer().render_dtype
|
||||
|
||||
class ClangGraph(GraphRunner):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, List, Any, cast
|
||||
import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
|
||||
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.driver.hip_comgr import compile_hip
|
||||
from tinygrad.runtime.ops_hsa import HSACompiler
|
||||
import tinygrad.runtime.autogen.kfd as kfd
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
import tinygrad.runtime.autogen.amd_gpu as amd_gpu
|
||||
@@ -68,7 +69,6 @@ def create_sdma_packets():
|
||||
sdma_pkts = create_sdma_packets()
|
||||
|
||||
class AMDCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("AMD", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
super().__init__(f"compile_hip_{self.arch}")
|
||||
@@ -583,7 +583,8 @@ class AMDDevice(Compiled):
|
||||
self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
|
||||
super().__init__(device, AMDAllocator(self), HIPRenderer(), HSACompiler(self.arch),
|
||||
functools.partial(AMDProgram, self),
|
||||
functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
|
||||
|
||||
def synchronize(self):
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import ctypes, subprocess, pathlib, tempfile
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator
|
||||
from tinygrad.helpers import cpu_time_execution
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
|
||||
class ClangCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("CLANG", supports_float4=False, has_local=False, renderer=ClangRenderer)
|
||||
def compile(self, src:str) -> bytes:
|
||||
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
|
||||
with tempfile.NamedTemporaryFile(delete=True) as output_file:
|
||||
@@ -25,4 +24,4 @@ class ClangProgram:
|
||||
class ClangDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
from tinygrad.runtime.graph.clang import ClangGraph
|
||||
super().__init__(device, MallocAllocator, ClangCompiler("compile_clang"), ClangProgram, ClangGraph)
|
||||
super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler("compile_clang"), ClangProgram, ClangGraph)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
|
||||
from pathlib import Path
|
||||
from dataclasses import replace
|
||||
from typing import Tuple, Optional, List
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator, MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator, MallocAllocator
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.renderer.assembly import PTXRenderer
|
||||
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401
|
||||
@@ -53,21 +52,17 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
|
||||
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
|
||||
|
||||
class PTXCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
|
||||
shared_max=49152, renderer=PTXRenderer)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
self.version = "7.8" if arch >= "sm_89" else "7.5"
|
||||
PTXCompiler.compiler_opts = replace(PTXCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
#PTXCompiler.compiler_opts = replace(PTXCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
super().__init__(f"compile_ptx_{self.arch}")
|
||||
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
|
||||
|
||||
class CUDACompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
|
||||
shared_max=49152, renderer=CUDARenderer)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
CUDACompiler.compiler_opts = replace(CUDACompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
#CUDACompiler.compiler_opts = replace(CUDACompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
|
||||
self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
|
||||
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
|
||||
@@ -176,6 +171,7 @@ class CUDADevice(Compiled):
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
|
||||
PTXRenderer(self.arch) if getenv("PTX") else CUDARenderer(self.arch),
|
||||
PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
|
||||
functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class DiskDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
self.size: Optional[int] = None
|
||||
self.count = 0
|
||||
super().__init__(device, DiskAllocator(self), None, None)
|
||||
super().__init__(device, DiskAllocator(self), None, None, None)
|
||||
def _might_open(self, size):
|
||||
self.count += 1
|
||||
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
|
||||
|
||||
@@ -4,7 +4,7 @@ import ctypes, functools, hashlib
|
||||
import tinygrad.runtime.autogen.opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompilerOptions
|
||||
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler
|
||||
|
||||
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
|
||||
OSX_TIMING_RATIO = (125/3) if OSX else 1.0
|
||||
@@ -14,7 +14,6 @@ def check(status):
|
||||
def checked(ret, status): return (check(status.value), ret)[1]
|
||||
|
||||
class CLCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("GPU", renderer=OpenCLRenderer)
|
||||
def __init__(self, device:CLDevice, compile_key:str):
|
||||
self.device = device
|
||||
super().__init__(f"compile_cl_{compile_key}")
|
||||
@@ -96,7 +95,7 @@ class CLDevice(Compiled):
|
||||
self.pending_copyin: List[memoryview] = []
|
||||
|
||||
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
|
||||
super().__init__(device, CLAllocator(self), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
|
||||
super().__init__(device, CLAllocator(self), OpenCLRenderer(), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
|
||||
def synchronize(self):
|
||||
check(cl.clFinish(self.queue))
|
||||
self.pending_copyin.clear()
|
||||
|
||||
@@ -3,7 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
|
||||
from typing import Tuple, TypeVar, List, Dict, Any
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
|
||||
from tinygrad.runtime.driver.hip_comgr import compile_hip
|
||||
@@ -42,7 +42,6 @@ class HSAProfiler:
|
||||
Profiler = HSAProfiler()
|
||||
|
||||
class HSACompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("HSA", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
super().__init__(f"compile_hip_{self.arch}")
|
||||
@@ -219,7 +218,7 @@ class HSADevice(Compiled):
|
||||
self.reusable_signals: List[hsa.hsa_signal_t] = []
|
||||
|
||||
from tinygrad.runtime.graph.hsa import HSAGraph
|
||||
super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
|
||||
super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
|
||||
|
||||
# Finish init: preallocate some signals + space for kernargs
|
||||
self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, functools
|
||||
from typing import Tuple
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator
|
||||
from tinygrad.helpers import DEBUG, cpu_time_execution
|
||||
from tinygrad.renderer.llvmir import uops_to_llvm_ir
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
import llvmlite.binding as llvm
|
||||
|
||||
class LLVMCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False, renderer=uops_to_llvm_ir)
|
||||
def __init__(self, device:LLVMDevice):
|
||||
self.device = device
|
||||
super().__init__("compile_llvm")
|
||||
@@ -43,4 +42,4 @@ class LLVMDevice(Compiled):
|
||||
backing_mod = llvm.parse_assembly(str())
|
||||
backing_mod.triple = llvm.get_process_triple()
|
||||
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
|
||||
super().__init__(device, MallocAllocator, LLVMCompiler(self), functools.partial(LLVMProgram, self))
|
||||
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
|
||||
|
||||
@@ -3,7 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile, functools
|
||||
import Metal, libdispatch
|
||||
from typing import List, Set, Any, Tuple, Optional
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator
|
||||
from tinygrad.device import Compiled, Compiler, LRUAllocator
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
def wait_check(cbuf: Any):
|
||||
@@ -12,7 +12,6 @@ def wait_check(cbuf: Any):
|
||||
raise RuntimeError(error)
|
||||
|
||||
class MetalCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64", shared_max=32768, renderer=MetalRenderer)
|
||||
def __init__(self, device:Optional[MetalDevice]):
|
||||
self.device = device
|
||||
super().__init__("compile_metal")
|
||||
@@ -97,7 +96,7 @@ class MetalDevice(Compiled):
|
||||
self.mv_in_metal: List[memoryview] = []
|
||||
self.track_cross_buffer: List[Any] = []
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
|
||||
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self),
|
||||
functools.partial(MetalProgram, self), MetalGraph)
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||
|
||||
@@ -6,4 +6,4 @@ class NpyAllocator(Allocator):
|
||||
def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
|
||||
|
||||
class NpyDevice(Compiled):
|
||||
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None)
|
||||
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
|
||||
from typing import Tuple, List, Any, cast
|
||||
from dataclasses import replace
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator, BufferOptions
|
||||
from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
|
||||
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes
|
||||
@@ -65,10 +64,9 @@ def nvdata64(data): return (data >> 32, data & 0xFFFFFFFF)
|
||||
def nvdata64_le(data): return (data & 0xFFFFFFFF, data >> 32)
|
||||
|
||||
class NVCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("NV", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152, renderer=CUDARenderer)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
NVCompiler.compiler_opts = replace(NVCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
#NVCompiler.compiler_opts = replace(NVCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
|
||||
cuda_check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
|
||||
self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
|
||||
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
|
||||
@@ -493,7 +491,7 @@ class NVDevice(Compiled):
|
||||
self.arch: str = 'sm_89' # TODO: fix
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self),
|
||||
super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), NVCompiler(self.arch), functools.partial(NVProgram, self),
|
||||
functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
|
||||
|
||||
self._cmdq_setup_compute_gpfifo()
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Compiler, CompilerOptions, Allocator
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def _load(m, i):
|
||||
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
@@ -177,15 +178,18 @@ class PythonProgram:
|
||||
i += 1
|
||||
return time.perf_counter() - st
|
||||
|
||||
def PythonRenderer(name:str, uops:UOpGraph) -> str:
|
||||
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
class PythonRenderer(Renderer):
|
||||
device = "PYTHON"
|
||||
def __init__(self):
|
||||
if getenv("EMULATE_METAL"): self.device, self.has_tensor_cores = "METAL", True
|
||||
if getenv("EMULATE_HSA"): self.device, self.has_tensor_cores = "HSA", True
|
||||
if getenv("EMULATE_CUDA"): self.device, self.has_tensor_cores = "CUDA", True
|
||||
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
|
||||
class PythonCompiler(Compiler):
|
||||
compiler_opts = CompilerOptions("METAL", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_METAL") else \
|
||||
(CompilerOptions("HSA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_HSA") else \
|
||||
(CompilerOptions("CUDA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_CUDA") else \
|
||||
CompilerOptions("PYTHON", renderer=PythonRenderer)))
|
||||
def compile(self, src:str) -> bytes: return base64.b64decode(src)
|
||||
|
||||
class PythonAllocator(Allocator):
|
||||
@@ -195,4 +199,4 @@ class PythonAllocator(Allocator):
|
||||
|
||||
class PythonDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)
|
||||
super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import ctypes
|
||||
from tinygrad.device import Compiled, MallocAllocator
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.runtime.ops_hsa import HSACompiler
|
||||
|
||||
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
|
||||
@@ -14,4 +15,4 @@ class RHIPProgram:
|
||||
class RHIPDevice(Compiled):
|
||||
def __init__(self, device:str=""):
|
||||
self.device = int(device.split(":")[1]) if ":" in device else 0
|
||||
super().__init__(device, MallocAllocator, HSACompiler("gfx1100"), RHIPProgram)
|
||||
super().__init__(device, MallocAllocator, HIPRenderer(), HSACompiler("gfx1100"), RHIPProgram)
|
||||
|
||||
Reference in New Issue
Block a user