add renderer class (#4524)

* add renderer class

* tests pass

* fix pylint

* fix tensor cores
This commit is contained in:
George Hotz
2024-05-10 21:40:02 -07:00
committed by GitHub
parent b00b6b16f0
commit 347a3acb37
31 changed files with 536 additions and 527 deletions

View File

@@ -55,18 +55,18 @@ if __name__ == "__main__":
lins:List[Linearizer] = []
# always try hand coded opt
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append(lin)
# maybe try tensor cores
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append(lin)
# try a beam search
if beam:=getenv("BEAM"):
lin = Linearizer(*si.ast, opts=device.compiler.compiler_opts)
lin = Linearizer(*si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)

View File

@@ -48,7 +48,7 @@ if __name__ == "__main__":
for ast in ast_dedup:
k = Device["CLANG"].get_linearizer(*ast)
k.linearize()
src = Device["CLANG"].compiler.compiler_opts.renderer(to_function_name(k.name), k.uops)
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
srcs[ast] = (k.name, src)
print("functions:", len(srcs))
used_buffers = dedup(flatten([si.bufs for si in sched]))

View File

@@ -25,7 +25,7 @@ if __name__ == '__main__':
for i, ast_str in enumerate(ast_strs):
print(f"optimizing {i}/{len(ast_strs)}\nast={ast_str}")
lin = ast_str_to_lin(ast_str, opts=device.compiler.compiler_opts)
lin = ast_str_to_lin(ast_str, opts=device.renderer)
rawbufs = bufs_from_lin(lin)
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))

View File

@@ -32,7 +32,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
compile_hip(code)
return (time.perf_counter() - st) * 1000
tinygrad_tm = min([time_compile(Device[Device.DEFAULT].compiler.compiler_opts.renderer(f"test{i}", lin.uops)) for i in range(10)])
tinygrad_tm = min([time_compile(Device[Device.DEFAULT].renderer.render(f"test{i}", lin.uops)) for i in range(10)])
ref_tm = min([time_compile(reference.format(name=f"test{i}")) for i in range(10)])
print(f"tinygrad {tinygrad_tm:6.2f} ms")
print(f"reference {ref_tm:6.2f} ms")

View File

@@ -15,7 +15,7 @@ if __name__ == "__main__":
beam_won, tested = 0, 0
for num, ast in enumerate(ast_strs[:test_n]):
def new_lin(): return ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
def new_lin(): return ast_str_to_lin(ast, opts=dev.renderer)
k = new_lin()
# k.required_optimizations()

View File

@@ -23,12 +23,12 @@ if __name__ == "__main__":
average_tm_cuda, average_tm_nv = 0, 0
for num,ast in enumerate(ast_strs):
# cuda compile
culin = ast_str_to_lin(ast, opts=cudev.compiler.compiler_opts)
culin = ast_str_to_lin(ast, opts=cudev.renderer)
culin.hand_coded_optimizations()
cuda_prg = cudev.to_runner(culin)
cubufs = bufs_from_lin(culin)
nvlin = ast_str_to_lin(ast, opts=nvdev.compiler.compiler_opts)
nvlin = ast_str_to_lin(ast, opts=nvdev.renderer)
nvlin.hand_coded_optimizations()
nv_prg = nvdev.to_runner(nvlin)
nvbufs = bufs_from_lin(nvlin)

View File

@@ -26,7 +26,7 @@ if __name__ == "__main__":
average_tm_cuda, average_tm_ptx = 0, 0
for num,ast in enumerate(ast_strs):
# cuda compile
lin = ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
lin = ast_str_to_lin(ast, opts=dev.renderer)
lin.hand_coded_optimizations()
cuda_prg = dev.to_runner(lin)

View File

@@ -7,7 +7,7 @@ class TestDeviceSpeed(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dev = Device[Device.DEFAULT]
cls.empty = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", UOpGraph())
cls.empty = Device[Device.DEFAULT].renderer.render("test", UOpGraph())
def test_empty_compile(self):
with Timing("compiler "):

View File

@@ -141,7 +141,7 @@ class TestLinearizer(unittest.TestCase):
assert num_ops <= 1, "more alu uops than needed"
def test_reduce_upcast(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
if not Device[Device.DEFAULT].renderer.supports_float4:
self.skipTest("device does not support upcast")
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
@@ -157,7 +157,7 @@ class TestLinearizer(unittest.TestCase):
assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4)
def test_upcast_with_locals(self):
if not (opts:=Device[Device.DEFAULT].compiler.compiler_opts).has_local or not opts.has_shared or not opts.supports_float4:
if not (opts:=Device[Device.DEFAULT].renderer).has_local or not opts.has_shared or not opts.supports_float4:
self.skipTest("device does not support upcasted reduce with locals")
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
@@ -218,16 +218,16 @@ class TestLinearizer(unittest.TestCase):
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
def test_tensor_cores(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
def test_tensor_cores_padded(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
@@ -251,9 +251,9 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "RHIP", "RHIP is really slow here")
def test_tensor_cores_multi_reduce(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None
@@ -358,7 +358,7 @@ class TestLinearizer(unittest.TestCase):
helper(Tensor.arange(256), max_ops=2)
helper(Tensor.arange(255), max_ops=0)
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k):
@@ -567,7 +567,7 @@ class TestHandCodedOpts(unittest.TestCase):
assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49
def test_matvec(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local:
if not Device[Device.DEFAULT].renderer.has_local:
self.skipTest("Only devices with locals")
N = 128
a = Tensor.rand(1, N).realize()
@@ -618,7 +618,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
class TestKernelOpts(unittest.TestCase):
def test_local_and_grouped_reduce(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
self.skipTest("Only Compiled uses linearizer with locals and shared")
N = 128
@@ -664,7 +664,7 @@ class TestKernelOpts(unittest.TestCase):
])
def test_matmul(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
self.skipTest("Only Compiled uses linearizer with locals and shared")
N = 128
@@ -694,7 +694,7 @@ class TestKernelOpts(unittest.TestCase):
])
def test_double_reduce(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
self.skipTest("Only Compiled uses linearizer with locals and shared")
N = 128
@@ -721,7 +721,7 @@ class TestKernelOpts(unittest.TestCase):
])
def test_invalid_tensor_core_extra_opts(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
@@ -742,25 +742,25 @@ class TestKernelOpts(unittest.TestCase):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
def test_buf_index_not_found_tensor_core(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Linearizer(ast, opts=Device[Device.DEFAULT].compiler.compiler_opts)
k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
def test_tensor_core_opts(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
N = 128
Tensor.manual_seed(1552)
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
@@ -889,7 +889,7 @@ class TestKernelOpts(unittest.TestCase):
])
def test_color_shapes_with_local(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared:
self.skipTest("Only Compiled uses linearizer with locals and shared")
N = 32
@@ -959,7 +959,7 @@ class TestLinearizerHelper(unittest.TestCase):
assert expand_idxs(idxs) == (uidx0, NumNode(0), uidx1)
class TestLinearizerUOptimize(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "device doesn't support float4")
def test_grouped_store_phis(self):
x, y = Tensor.randn(64,64), Tensor.randn(64,64)
out = x.matmul(y)
@@ -973,7 +973,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
for val in store_vals:
assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "device doesn't support float4")
def test_grouped_store_values(self):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
@@ -986,8 +986,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
def test_grouped_store_locals_and_globals(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared or \
not Device[Device.DEFAULT].renderer.supports_float4:
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
@@ -1011,8 +1011,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
def test_grouped_store_local_only(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.has_shared or \
not Device[Device.DEFAULT].renderer.supports_float4:
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
@@ -1031,7 +1031,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
assert stores[1].vin[-1].dtype == dtypes.float
def test_skip_unmatching_upcasts(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.supports_float4:
self.skipTest("Needs locals and float4")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [
@@ -1047,7 +1047,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
def test_skip_unmatching_upcasts_with_gep(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
if not Device[Device.DEFAULT].renderer.has_local or not Device[Device.DEFAULT].renderer.supports_float4:
self.skipTest("Needs locals and float4")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),

View File

@@ -12,8 +12,8 @@ from tinygrad.codegen.uops import exec_alu, UOpGraph
from test.helpers import is_dtype_supported
def _uops_to_prg(uops):
src = Device[Device.DEFAULT].compiler.compiler_opts.renderer("test", uops)
has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
src = Device[Device.DEFAULT].renderer.render("test", uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
from tinygrad.device import Device, CompilerOptions
from tinygrad.device import Device
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes, ImageDType, DType
from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
from tinygrad.shape.shapetracker import ShapeTracker
@@ -70,9 +71,8 @@ class LocalBuffer(NamedTuple):
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
class Kernel:
def __init__(self, *ast:LazyOp, opts:Optional[CompilerOptions]=None):
self.opts = opts if opts is not None else (device.compiler.compiler_opts if (device:=Device[Device.DEFAULT]).compiler is not None else
CompilerOptions(Device.DEFAULT))
def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
self.ast = ast

View File

@@ -5,11 +5,12 @@ from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.features.image import to_image_idx
from tinygrad.renderer import Program
from tinygrad.codegen.uops import UOps, UOp, UOpGraph
@@ -441,3 +442,12 @@ class Linearizer(Kernel):
ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)]
cache[x] = ret
return ret
# Linearize this kernel and package the result into an immutable Program record
# (name, rendered source string, device, launch dims, uops, flop/mem estimates).
# NOTE(review): indentation was stripped by the diff rendering; in the real file
# this is a method body indented under class Linearizer.
def to_program(self) -> Program:
# lower the scheduled AST into the uop graph before rendering source
self.linearize()
# static estimates from the (first) output AST — an upper bound that ignores
# optimizations applied during lowering
info = get_lazyop_info(self.ast[0])
# dynamic estimates counted per single workitem from the actual uop graph
ops, mem = self.uops.flops_mem()
# total workitems = product of global and local launch dimensions (empty -> 1)
run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
return Program(self.name, self.opts.render(to_function_name(self.name), self.uops), self.opts.device,
self.global_size, self.local_size, self.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))

View File

@@ -2,13 +2,13 @@ from __future__ import annotations
import multiprocessing
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, ClassVar, Callable, Any
from typing import TYPE_CHECKING, List, Optional, Dict, Tuple, Any
import importlib, inspect, functools, pathlib, os, ctypes
from tinygrad.helpers import prod, getenv, all_int, to_function_name, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.helpers import getenv, all_int, diskcache_get, diskcache_put, DEBUG,BEAM,NOOPT, GlobalCounters, flat_mv, from_mv
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import DType, ImageDType
from tinygrad.ops import LazyOp, get_lazyop_info
from tinygrad.codegen.uops import UOpGraph
from tinygrad.ops import LazyOp
from tinygrad.renderer import Renderer, Program
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer
@@ -181,63 +181,7 @@ class Runner:
# **************** for Compiled Devices ****************
@dataclass(frozen=True)
class Program:
name:str
src:str
dname:str
global_size:Optional[List[int]]=None
local_size:Optional[List[int]]=None
uops:Optional[UOpGraph]=None
op_estimate:sint=0
mem_estimate:sint=0
@functools.cached_property
def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
@functools.cached_property
def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
@functools.cached_property
def outcount(self) -> int: return sum(x[1] for x in self.globals)
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
def launch_dims(self, var_vals:Dict[Variable, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
@dataclass(frozen=True)
class CompilerOptions:
device: str = ""
suffix: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
has_local: bool = True
has_shared: bool = True
has_tensor_cores: bool = False
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None
shared_max: int = 32768
renderer: Callable = fake_renderer
def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
k.linearize()
info = get_lazyop_info(k.ast[0])
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
return Program(k.name, self.renderer(to_function_name(k.name), k.uops),
override_device if override_device else self.device,
k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
class Compiler:
compiler_opts: ClassVar[CompilerOptions]
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
def compile_cached(self, src:str) -> bytes:
@@ -276,24 +220,25 @@ class CompiledRunner(Runner):
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
class Compiled:
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
self.renderer = renderer if renderer else Renderer()
def synchronize(self): pass # override this in your device
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.compiler_opts.to_program(k, override_device=self.dname))
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(replace(k.to_program(), dname=self.dname))
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
if DEBUG >= 3:
from tinygrad.features.graph import print_tree
for op in ast: print_tree(op)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(*ast, opts=self.compiler.compiler_opts)
k = Linearizer(*ast, opts=self.renderer)
k.required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
kb, k_opt = Linearizer(*ast, opts=self.compiler.compiler_opts), k
kb, k_opt = Linearizer(*ast, opts=self.renderer), k
kb.required_optimizations()
rawbufs = bufs_from_lin(kb, allocate=False)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
@@ -301,7 +246,7 @@ class Compiled:
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
lins: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
if used_tensor_cores:
lins.append(("hc", Linearizer(*ast, opts=self.compiler.compiler_opts)))
lins.append(("hc", Linearizer(*ast, opts=self.renderer)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))

View File

@@ -51,7 +51,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
try:
x[1].linearize()
if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
p = compiler.compiler_opts.to_program(x[1])
p = x[1].to_program()
st = time.perf_counter()
prog = compiler.compile(p.src)
et = time.perf_counter() - st
@@ -174,7 +174,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
p = dev.compiler.compiler_opts.to_program(lin)
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

View File

@@ -0,0 +1,49 @@
from typing import Optional, List, Tuple, Dict
import functools
from dataclasses import dataclass
from tinygrad.helpers import to_function_name
from tinygrad.codegen.uops import UOpGraph
from tinygrad.shape.symbolic import sym_infer, sint, Variable
@dataclass(frozen=True)
class Program:
"""Immutable description of one compiled-ready kernel: its name, rendered
source, target device, launch dimensions, uop graph, and cost estimates.
NOTE(review): leading indentation was stripped by the diff rendering; in the
real file the class body is indented normally."""
name:str          # human-readable kernel name (may contain colors/symbols)
src:str           # rendered source code for the target backend
dname:str         # device name this program was rendered for
global_size:Optional[List[int]]=None  # grid dims; entries may be symbolic (sint)
local_size:Optional[List[int]]=None   # workgroup dims; entries may be symbolic
uops:Optional[UOpGraph]=None          # the uop graph this source was rendered from
op_estimate:sint=0    # estimated FLOPs for one full launch
mem_estimate:sint=0   # estimated bytes moved for one full launch
# cached: frozen dataclass fields never change, so these are computed once
@functools.cached_property
def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
@functools.cached_property
def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
# number of output buffers: globals entries are (index, is_output) pairs
@functools.cached_property
def outcount(self) -> int: return sum(x[1] for x in self.globals)
# name sanitized into a valid C-style identifier for codegen
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
def launch_dims(self, var_vals:Dict[Variable, int]):
"""Resolve any symbolic launch dimensions to concrete ints using var_vals;
returns (global_size, local_size), either of which may be None."""
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
class Renderer:
"""Base class describing a codegen backend: capability flags consulted by the
Kernel/Linearizer optimizer, plus render() turning a UOpGraph into source.
Subclasses (e.g. PTXRenderer) override the class attributes and render().
NOTE(review): leading indentation was stripped by the diff rendering; in the
real file the class body is indented normally."""
device: str = ""    # device name this renderer targets (e.g. "CUDA")
suffix: str = ""    # variant suffix distinguishing renderers for one device (e.g. "PTX")
# TODO: make this generic with a list of supported types
supports_float4: bool = True    # backend has vectorized float4 loads/stores
has_local: bool = True          # backend supports local (workgroup) dimensions
has_shared: bool = True         # backend supports shared/local memory
has_tensor_cores: bool = False  # backend can emit tensor-core (WMMA) ops
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None  # per-axis grid size limits, None = unlimited
local_max: Optional[List[int]] = None   # per-axis workgroup size limits, None = unlimited
shared_max: int = 32768                 # shared memory budget in bytes
def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")

View File

@@ -1,15 +1,16 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
import functools, struct, copy
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct, copy
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
from tinygrad.renderer import Renderer
def render_val(x, dtype):
if dtypes.is_float(dtype):
if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
elif dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
@@ -33,191 +34,16 @@ def ptr_ar(root, uops):
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
class AssemblyLanguage(NamedTuple):
kernel_prefix: str = ""
barrier: str = ""
load_global: bool = False
label_prefix: str = ""
gid: List[str] = []
gdim: List[str] = []
lid: List[str] = []
const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
asm_for_op: Dict[Op, Callable[...,str]] = {}
types: Dict[DType, str] = INVERSE_DTYPES_DICT
supports_half: List[Op] = []
class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]: raise NotImplementedError()
def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
def mem_type(self, dtype) -> str: raise NotImplementedError()
def uops_to_asm(lang:AssemblyLanguage, function_name:str, _uops:UOpGraph) -> str:
  """Lower a UOpGraph to assembly source text using the dialect in `lang`.

  Works on a deep copy of the graph: a PatternMatcher pretransform first
  rewrites uops the target cannot express directly (bool compares and memory
  traffic, unsupported half ops, mul+add fusion into MULACC), then pointer
  arithmetic is materialized and the graph re-optimized before emission.
  """
  # editing the uops breaks beam search
  uops = copy.deepcopy(_uops)
  kernel:List[str] = []
  bufs = []

  matcher = PatternMatcher([
    # bool CMPEQ -> NEG(XOR); bool CMPLT x<y -> NEG(x)*y
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
     lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
     lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
    # fuse float ADD(MUL(a,b), c) into a single MULACC
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
      "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
     lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
    # half ops the dialect doesn't support natively: compute in f32 and cast back
    *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
       lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
      for op in lang.asm_for_op.keys() if op not in lang.supports_half],
    # bool loads/stores go through int8/uint8 memory; store gates are cast to bool
    ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
     lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
     lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
     lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
     lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
  ])

  # here we do a pretransform on UOps to fix some shortcomings of PTX
  # all uops must be a register
  matcher.rewrite_graph(uops)
  for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
  uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
  uops.optimize_loops()

  def kk(*s: str): kernel.append("\n".join(s))

  # SSA register naming: one counter per prefix+type; r maps each UOp to its register name(s)
  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, Union[List[str], str]] = {}
  def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
    nonlocal c, r
    prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
    c[prefix] += 1
    if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  # branch-label naming, same scheme as ssa() but for loop/if targets
  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(prefix:str, u:UOp):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]

  def const(x:ConstType, dtype:DType, mov=False):
    # some dtypes can't be immediates (lang.const_requires_mov); mov them into a fresh register
    if mov or dtype in lang.const_requires_mov:
      kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
      return out
    return lang.render_const(x, dtype)

  def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    # no-op when the types already match; otherwise emit a cast into a new register
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
    return ret

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop is UOps.IF:
      assert vin[0].dtype is not None
      kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
    elif uop is UOps.ENDLOOP:
      # increment the loop index, compare against the bound, and branch back if still below it
      kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
         lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
      kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
    elif uop is UOps.ENDIF:
      kk(f"{r_label[vin[0]]}:")
    elif uop is UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      if vin[2].dtype.count > 1:
        # vectorized store (st.vN), optionally predicated on vin[3]
        kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
            f"st{u.arg}.v{vin[2].dtype.count}.{lang.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
      else:
        kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
    else:
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
      elif uop is UOps.ALU:
        assert vin[0].dtype is not None
        if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
          # pass in the other dtype here
          kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
        else:
          kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
      elif uop is UOps.DEFINE_ACC:
        # accumulators are initialized with a mov of the constant; vector accs get one register per lane
        if dtype.count > 1:
          r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
          for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
        else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
      elif uop is UOps.SPECIAL:
        assert args[1][0] != "i", "idx not supported"
        # read the group or local id into a named u32 register, declared at the top of the kernel
        kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        r[u] = "%" + args[1]
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop is UOps.CONST:
        if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
        else: r[u] = const(args, dtype, mov=True)
      elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
      elif uop is UOps.LOAD:
        assert vin[1].dtype is not None
        if dtype.count > 1:
          # vectorized load (ld.vN); when gated, pre-fill the lanes with zero as the alt value
          r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
          if(len(vin)>3):
            for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
          kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
            + f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
        else:
          kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
                              alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
      elif uop is UOps.PHI:
        # a PHI is just a mov back into the accumulator register
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop in {UOps.CAST, UOps.BITCAST}:
        assert vin[0].dtype is not None
        if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
        else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
      elif uop is UOps.DEFINE_LOCAL:
        # TODO: we should sum these, and fetch 0xC000 from somewhere
        assert args[1]*dtype.itemsize <= 0xC000, "too large local"
        kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop is UOps.DEFINE_VAR:
        bufs.append((args.expr, dtype))
        r[u] = f"%{args.expr}"
        if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
      elif uop is UOps.DEFINE_GLOBAL:
        bufs.append((nm:=f"data{args[0]}", dtype))
        r[u] = f"%{nm}"
        if lang.load_global:
          # pointer params are loaded as u64 addresses
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*lang.render_load(nm, ssa('dat', u, lang.types[dt]), dt, ss=".param"))
      elif uop is UOps.WMMA:
        # pack pairs of 16-bit input registers into b32 registers as mma.sync expects
        wmma = []
        for vv in vin[:2]:
          for i in range(0, len(r[vv]), 2):
            wmma.append(ssa("wmma", dtype="b32"))
            kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
        r[u] = r[vin[2]]
        kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
 {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
      else: raise NotImplementedError(f"no code for {uop}")

  return lang.render_kernel(kernel, function_name, bufs, c.items())
class PTXLanguage(AssemblyLanguage):
# language options
kernel_prefix = """.version VERSION
.target TARGET
.address_size 64
@@ -229,7 +55,7 @@ class PTXLanguage(AssemblyLanguage):
gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
asm_for_op = {
asm_for_op: Dict[Op, Callable] = {
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
@@ -245,13 +71,14 @@ class PTXLanguage(AssemblyLanguage):
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
TernaryOps.WHERE]
# HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
types = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
const_requires_mov = [dtypes.half, dtypes.bool]
const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
val = render_val(x, dtype)
@@ -270,7 +97,7 @@ class PTXLanguage(AssemblyLanguage):
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
assert dtype is not dtypes.bool
if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]
@@ -291,4 +118,160 @@ class PTXLanguage(AssemblyLanguage):
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
"\n}")
PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
  def render(self, name:str, _uops:UOpGraph) -> str:
    """Lower the UOpGraph `_uops` to PTX source text for a kernel called `name`.

    Works on a deep copy of the graph: a PatternMatcher pretransform first
    rewrites uops PTX cannot express directly (bool compares and memory
    traffic, unsupported half ops, mul+add fusion into MULACC), then pointer
    arithmetic is materialized and the graph re-optimized before emission.
    """
    # editing the uops breaks beam search
    uops = copy.deepcopy(_uops)
    kernel:List[str] = []
    bufs = []

    matcher = PatternMatcher([
      # bool CMPEQ -> NEG(XOR); bool CMPLT x<y -> NEG(x)*y
      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
       lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
       lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
      # fuse float ADD(MUL(a,b), c) into a single MULACC
      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
        "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
       lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
      # half ops PTX doesn't support natively here: compute in f32 and cast back
      *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
         lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
        for op in self.asm_for_op.keys() if op not in self.supports_half],
      # bool loads/stores go through int8/uint8 memory; store gates are cast to bool
      ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
        "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
       lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
      ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
       lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
       lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
       lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
       lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
    ])

    # here we do a pretransform on UOps to fix some shortcomings of PTX
    # all uops must be a register
    matcher.rewrite_graph(uops)
    for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
    uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
    uops.optimize_loops()

    def kk(*s: str): kernel.append("\n".join(s))

    # SSA register naming: one counter per prefix+type; r maps each UOp to its register name(s)
    c: DefaultDict[str, int] = defaultdict(int)
    r: Dict[UOp, Union[List[str], str]] = {}
    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
      nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
      c[prefix] += 1
      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
      return f"%{prefix}{c[prefix]-1}"

    # branch-label naming, same scheme as ssa() but for loop/if targets
    c_label: DefaultDict[str, int] = defaultdict(int)
    r_label: Dict[UOp, str] = {}
    def ssa_label(prefix:str, u:UOp):
      nonlocal c_label, r_label
      c_label[prefix] += 1
      r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
      return r_label[u]

    def const(x:ConstType, dtype:DType, mov=False):
      # some dtypes can't be immediates (self.const_requires_mov); mov them into a fresh register
      if mov or dtype in self.const_requires_mov:
        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
        return out
      return self.render_const(x, dtype)

    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
      # no-op when the types already match; otherwise emit a cast into a new register
      if atype == dtype:
        if u: r[u] = a
        return a
      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
      return ret

    for u in uops:
      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
      if uop is UOps.IF:
        assert vin[0].dtype is not None
        kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
      elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
      elif uop is UOps.ENDLOOP:
        # increment the loop index, compare against the bound, and branch back if still below it
        kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
      elif uop is UOps.ENDIF:
        kk(f"{r_label[vin[0]]}:")
      elif uop is UOps.STORE:
        assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
        if vin[2].dtype.count > 1:
          # vectorized store (st.vN), optionally predicated on vin[3]
          kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
              f"st{u.arg}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
        else:
          kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
        elif uop is UOps.ALU:
          assert vin[0].dtype is not None
          if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
            # pass in the other dtype here
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
          else:
            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
        elif uop is UOps.DEFINE_ACC:
          # accumulators are initialized with a mov of the constant; vector accs get one register per lane
          if dtype.count > 1:
            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
          else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
        elif uop is UOps.SPECIAL:
          assert args[1][0] != "i", "idx not supported"
          # read the group or local id into a named u32 register, declared at the top of the kernel
          kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
          r[u] = "%" + args[1]
          kernel = [f".reg .u32 %{args[1]};"] + kernel
        elif uop is UOps.CONST:
          if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
          else: r[u] = const(args, dtype, mov=True)
        elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
        elif uop is UOps.LOAD:
          assert vin[1].dtype is not None
          if dtype.count > 1:
            # vectorized load (ld.vN); when gated, pre-fill the lanes with zero as the alt value
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
            if(len(vin)>3):
              for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
            kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
              + f" ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
          else:
            kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
                                alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
        elif uop is UOps.PHI:
          # a PHI is just a mov back into the accumulator register
          kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
          r[u] = r[vin[0]]
        elif uop in {UOps.CAST, UOps.BITCAST}:
          assert vin[0].dtype is not None
          if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
          else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
        elif uop is UOps.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
        elif uop is UOps.DEFINE_VAR:
          bufs.append((args.expr, dtype))
          r[u] = f"%{args.expr}"
          if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is UOps.DEFINE_GLOBAL:
          bufs.append((nm:=f"data{args[0]}", dtype))
          r[u] = f"%{nm}"
          if self.load_global:
            # pointer params are loaded as u64 addresses
            dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
            kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
        elif uop is UOps.WMMA:
          # pack pairs of 16-bit input registers into b32 registers as mma.sync expects
          wmma = []
          for vv in vin[:2]:
            for i in range(0, len(r[vv]), 2):
              wmma.append(ssa("wmma", dtype="b32"))
              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
          r[u] = r[vin[2]]
          kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
 {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
        else: raise NotImplementedError(f"no code for {uop}")

    return self.render_kernel(kernel, name, bufs, c.items())

View File

@@ -1,13 +1,14 @@
from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
import math
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
import os, math
from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph
from tinygrad.renderer import Renderer
class CStyleLanguage(NamedTuple):
class CStyleLanguage(Renderer):
kernel_prefix: str = ""
buffer_prefix: str = ""
buffer_suffix: str = ""
@@ -17,8 +18,6 @@ class CStyleLanguage(NamedTuple):
arg_int_prefix: str = "const int"
barrier: str = ""
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
global_max: List[int] = []
local_max: List[int] = []
extra_args: List[str] = []
float4: Optional[str] = None
uses_vload: bool = False
@@ -88,100 +87,107 @@ class CStyleLanguage(NamedTuple):
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str:
kernel = []
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
depth = 1
def kk(s): kernel.append(" "*depth+s)
def render(self, name:str, uops:UOpGraph) -> str:
kernel = []
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
c[prefix] += 1
return ret
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
c[prefix] += 1
return ret
child_count = Counter(v for ru in uops for v in ru.vin)
child_count = Counter(v for ru in uops for v in ru.vin)
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
depth += 1
elif uop is UOps.BARRIER: kk(lang.barrier)
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
else: operands = [r[v] for v in vin]
val = lang.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa('precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = lang.render_cast([precast], dtype, bitcast=True)
else:
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
elif uop is UOps.BARRIER: kk(self.barrier)
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
else: operands = [r[v] for v in vin]
val = self.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa('precast')
kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = self.render_cast([precast], dtype, bitcast=True)
else:
val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
if child_count[u] <= 1: r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(self.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
return lang.render_kernel(function_name, kernel, bufs, uops)
return self.render_kernel(name, kernel, bufs, uops)
class ClangLanguage(CStyleLanguage):
class ClangRenderer(CStyleLanguage):
device = "CLANG"
supports_float4 = False
has_local = False
# language options
buffer_suffix = " restrict"
type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
def ClangRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(ClangLanguage(), name, uops)
class OpenCLLanguage(CStyleLanguage):
class OpenCLRenderer(CStyleLanguage):
device = "GPU"
# language options
kernel_prefix = "__kernel "
buffer_prefix = "__global "
smem_align = "__attribute__ ((aligned (16))) "
@@ -197,9 +203,13 @@ class OpenCLLanguage(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def OpenCLRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(OpenCLLanguage(), name, uops)
class MetalLanguage(CStyleLanguage):
class MetalRenderer(CStyleLanguage):
device = "METAL"
has_tensor_cores=os.uname().machine == "arm64"
shared_max=32768
# language options
kernel_prefix = "kernel "
buffer_prefix = "device "
smem_prefix = "threadgroup "
@@ -227,7 +237,6 @@ class MetalLanguage(CStyleLanguage):
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def MetalRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(MetalLanguage(), name, uops)
code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
@@ -240,7 +249,15 @@ def _make_cuda_dtype(base_type, name, cnt):
vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
class CUDALanguage(CStyleLanguage):
class CUDARenderer(CStyleLanguage):
device = "CUDA"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
# language options
kernel_prefix = "extern \"C\" __global__ "
smem_prefix = "__shared__ "
smem_prefix_for_cast = False
@@ -271,7 +288,6 @@ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}
return c;}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
def CUDARenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(CUDALanguage(), name, uops)
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
@@ -295,7 +311,12 @@ def _make_hip_dtype(base_type, name, cnt):
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class HIPLanguage(CStyleLanguage):
class HIPRenderer(CStyleLanguage):
device = "HSA"
has_tensor_cores = True
shared_max = 65536
# language options
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
@@ -357,5 +378,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
def HIPRenderer(name:str, uops:UOpGraph) -> str: return uops_to_cstyle(HIPLanguage(), name, uops)

View File

@@ -4,6 +4,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOpGraph
from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -65,88 +66,95 @@ def cast(bb, val, input_type, output_type, bitcast=False):
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
class LLVMRenderer(Renderer):
device = "LLVM"
supports_float4=False
has_local=False
has_shared=False
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
def render(self, name:str, uops:UOpGraph) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
func.attributes.add('"no-nans-fp-math"="true"')
for bufname,dtype in buf_to_dtype.items():
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
if len(vin) > 3:
with bb[-1].if_then(lvars[vin[3]]): store_op()
else: store_op()
elif uop is UOps.ENDLOOP:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
for bufname,dtype in buf_to_dtype.items():
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2].block)
phis.append((rp, lvars[rp]))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
if len(vin) > 2:
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
if len(vin) > 3:
with bb[-1].if_then(lvars[vin[3]]):
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
else:
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
lvars[u] = val
elif uop is UOps.PHI:
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
elif uop is UOps.ENDLOOP:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
bb[-1].ret_void()
return str(module)
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2].block)
phis.append((rp, lvars[rp]))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
if len(vin) > 2:
aug_idx = bb[-1].select(lvars[vin[2]], lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = bb[-1].select(lvars[vin[2]], val, lvars[vin[3]])
else:
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
lvars[u] = val
elif uop is UOps.PHI:
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
bb[-1].ret_void()
return str(module)

View File

@@ -6,8 +6,8 @@ from tinygrad.device import Buffer, Device, CompiledRunner
from tinygrad.engine.realize import ExecItem
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_clang import ClangProgram
from tinygrad.renderer.cstyle import ClangLanguage
render_dtype = ClangLanguage().render_dtype
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype
class ClangGraph(GraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import Tuple, List, Any, cast
import os, fcntl, ctypes, functools, re, pathlib, mmap, struct, errno, subprocess, time
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip
from tinygrad.runtime.ops_hsa import HSACompiler
import tinygrad.runtime.autogen.kfd as kfd
import tinygrad.runtime.autogen.hsa as hsa
import tinygrad.runtime.autogen.amd_gpu as amd_gpu
@@ -68,7 +69,6 @@ def create_sdma_packets():
sdma_pkts = create_sdma_packets()
class AMDCompiler(Compiler):
compiler_opts = CompilerOptions("AMD", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
def __init__(self, arch:str):
self.arch = arch
super().__init__(f"compile_hip_{self.arch}")
@@ -583,7 +583,8 @@ class AMDDevice(Compiled):
self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
super().__init__(device, AMDAllocator(self), HIPRenderer(), HSACompiler(self.arch),
functools.partial(AMDProgram, self),
functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
def synchronize(self):

View File

@@ -1,10 +1,9 @@
import ctypes, subprocess, pathlib, tempfile
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
from tinygrad.device import Compiled, Compiler, MallocAllocator
from tinygrad.helpers import cpu_time_execution
from tinygrad.renderer.cstyle import ClangRenderer
class ClangCompiler(Compiler):
compiler_opts = CompilerOptions("CLANG", supports_float4=False, has_local=False, renderer=ClangRenderer)
def compile(self, src:str) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
@@ -25,4 +24,4 @@ class ClangProgram:
class ClangDevice(Compiled):
def __init__(self, device:str):
from tinygrad.runtime.graph.clang import ClangGraph
super().__init__(device, MallocAllocator, ClangCompiler("compile_clang"), ClangProgram, ClangGraph)
super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler("compile_clang"), ClangProgram, ClangGraph)

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
from pathlib import Path
from dataclasses import replace
from typing import Tuple, Optional, List
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator, MallocAllocator
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator, MallocAllocator
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.assembly import PTXRenderer
if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401
@@ -53,21 +52,17 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
class PTXCompiler(Compiler):
compiler_opts = CompilerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
shared_max=49152, renderer=PTXRenderer)
def __init__(self, arch:str):
self.arch = arch
self.version = "7.8" if arch >= "sm_89" else "7.5"
PTXCompiler.compiler_opts = replace(PTXCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
#PTXCompiler.compiler_opts = replace(PTXCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
super().__init__(f"compile_ptx_{self.arch}")
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
class CUDACompiler(Compiler):
compiler_opts = CompilerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
shared_max=49152, renderer=CUDARenderer)
def __init__(self, arch:str):
self.arch = arch
CUDACompiler.compiler_opts = replace(CUDACompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
#CUDACompiler.compiler_opts = replace(CUDACompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
@@ -176,6 +171,7 @@ class CUDADevice(Compiled):
from tinygrad.runtime.graph.cuda import CUDAGraph
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
PTXRenderer(self.arch) if getenv("PTX") else CUDARenderer(self.arch),
PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)

View File

@@ -35,7 +35,7 @@ class DiskDevice(Compiled):
def __init__(self, device:str):
self.size: Optional[int] = None
self.count = 0
super().__init__(device, DiskAllocator(self), None, None)
super().__init__(device, DiskAllocator(self), None, None, None)
def _might_open(self, size):
self.count += 1
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"

View File

@@ -4,7 +4,7 @@ import ctypes, functools, hashlib
import tinygrad.runtime.autogen.opencl as cl
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompilerOptions
from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler
# see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -14,7 +14,6 @@ def check(status):
def checked(ret, status): return (check(status.value), ret)[1]
class CLCompiler(Compiler):
compiler_opts = CompilerOptions("GPU", renderer=OpenCLRenderer)
def __init__(self, device:CLDevice, compile_key:str):
self.device = device
super().__init__(f"compile_cl_{compile_key}")
@@ -96,7 +95,7 @@ class CLDevice(Compiled):
self.pending_copyin: List[memoryview] = []
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
super().__init__(device, CLAllocator(self), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
super().__init__(device, CLAllocator(self), OpenCLRenderer(), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
def synchronize(self):
check(cl.clFinish(self.queue))
self.pending_copyin.clear()

View File

@@ -3,7 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
from typing import Tuple, TypeVar, List, Dict, Any
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
from tinygrad.device import Compiled, Compiler, CompilerOptions, BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
from tinygrad.runtime.driver.hip_comgr import compile_hip
@@ -42,7 +42,6 @@ class HSAProfiler:
Profiler = HSAProfiler()
class HSACompiler(Compiler):
compiler_opts = CompilerOptions("HSA", has_tensor_cores=True, shared_max=65536, renderer=HIPRenderer)
def __init__(self, arch:str):
self.arch = arch
super().__init__(f"compile_hip_{self.arch}")
@@ -219,7 +218,7 @@ class HSADevice(Compiled):
self.reusable_signals: List[hsa.hsa_signal_t] = []
from tinygrad.runtime.graph.hsa import HSAGraph
super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
# Finish init: preallocate some signals + space for kernargs
self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
import ctypes, functools
from typing import Tuple
from tinygrad.device import Compiled, Compiler, CompilerOptions, MallocAllocator
from tinygrad.device import Compiled, Compiler, MallocAllocator
from tinygrad.helpers import DEBUG, cpu_time_execution
from tinygrad.renderer.llvmir import uops_to_llvm_ir
from tinygrad.renderer.llvmir import LLVMRenderer
import llvmlite.binding as llvm
class LLVMCompiler(Compiler):
compiler_opts = CompilerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False, renderer=uops_to_llvm_ir)
def __init__(self, device:LLVMDevice):
self.device = device
super().__init__("compile_llvm")
@@ -43,4 +42,4 @@ class LLVMDevice(Compiled):
backing_mod = llvm.parse_assembly(str())
backing_mod.triple = llvm.get_process_triple()
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
super().__init__(device, MallocAllocator, LLVMCompiler(self), functools.partial(LLVMProgram, self))
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))

View File

@@ -3,7 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Set, Any, Tuple, Optional
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
def wait_check(cbuf: Any):
@@ -12,7 +12,6 @@ def wait_check(cbuf: Any):
raise RuntimeError(error)
class MetalCompiler(Compiler):
compiler_opts = CompilerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64", shared_max=32768, renderer=MetalRenderer)
def __init__(self, device:Optional[MetalDevice]):
self.device = device
super().__init__("compile_metal")
@@ -97,7 +96,7 @@ class MetalDevice(Compiled):
self.mv_in_metal: List[memoryview] = []
self.track_cross_buffer: List[Any] = []
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self),
functools.partial(MetalProgram, self), MetalGraph)
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)

View File

@@ -6,4 +6,4 @@ class NpyAllocator(Allocator):
def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
class NpyDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None)
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
from typing import Tuple, List, Any, cast
from dataclasses import replace
from tinygrad.device import Compiled, Compiler, CompilerOptions, LRUAllocator, BufferOptions
from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes
@@ -65,10 +64,9 @@ def nvdata64(data): return (data >> 32, data & 0xFFFFFFFF)
def nvdata64_le(data): return (data & 0xFFFFFFFF, data >> 32)
class NVCompiler(Compiler):
compiler_opts = CompilerOptions("NV", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152, renderer=CUDARenderer)
def __init__(self, arch:str):
self.arch = arch
NVCompiler.compiler_opts = replace(NVCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
#NVCompiler.compiler_opts = replace(NVCompiler.compiler_opts, has_tensor_cores=int(arch[3:]) >= 80)
cuda_check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
@@ -493,7 +491,7 @@ class NVDevice(Compiled):
self.arch: str = 'sm_89' # TODO: fix
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self),
super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), NVCompiler(self.arch), functools.partial(NVProgram, self),
functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
self._cmdq_setup_compute_gpfifo()

View File

@@ -5,9 +5,10 @@ from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, CompilerOptions, Allocator
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
from tinygrad.renderer import Renderer
def _load(m, i):
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
@@ -177,15 +178,18 @@ class PythonProgram:
i += 1
return time.perf_counter() - st
def PythonRenderer(name:str, uops:UOpGraph) -> str:
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
class PythonRenderer(Renderer):
device = "PYTHON"
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.has_tensor_cores = "METAL", True
if getenv("EMULATE_HSA"): self.device, self.has_tensor_cores = "HSA", True
if getenv("EMULATE_CUDA"): self.device, self.has_tensor_cores = "CUDA", True
def render(self, name:str, uops:UOpGraph) -> str:
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
class PythonCompiler(Compiler):
compiler_opts = CompilerOptions("METAL", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_METAL") else \
(CompilerOptions("HSA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_HSA") else \
(CompilerOptions("CUDA", has_tensor_cores=True, renderer=PythonRenderer) if getenv("EMULATE_CUDA") else \
CompilerOptions("PYTHON", renderer=PythonRenderer)))
def compile(self, src:str) -> bytes: return base64.b64decode(src)
class PythonAllocator(Allocator):
@@ -195,4 +199,4 @@ class PythonAllocator(Allocator):
class PythonDevice(Compiled):
def __init__(self, device:str):
super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)
super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)

View File

@@ -1,5 +1,6 @@
import ctypes
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.ops_hsa import HSACompiler
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
@@ -14,4 +15,4 @@ class RHIPProgram:
class RHIPDevice(Compiled):
def __init__(self, device:str=""):
self.device = int(device.split(":")[1]) if ":" in device else 0
super().__init__(device, MallocAllocator, HSACompiler("gfx1100"), RHIPProgram)
super().__init__(device, MallocAllocator, HIPRenderer(), HSACompiler("gfx1100"), RHIPProgram)