Remove webgpu, back to 5k lines (#3040)
* remove webgpu
* max 5000 lines
extra/backends/ops_webgpu.py (Normal file, 41 lines)
@@ -0,0 +1,41 @@
from wgpu.utils.device import get_default_device
from tinygrad.device import Compiled, Allocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import WGSLRenderer
import wgpu

wgpu_device = get_default_device()

def create_uniform(val: int) -> wgpu.GPUBuffer:
  buf = wgpu_device.create_buffer(size=4, usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST)
  wgpu_device.queue.write_buffer(buf, 0, val.to_bytes(4, "little"))
  return buf

class WebGPUProgram:
  def __init__(self, name:str, lib:bytes):
    self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib)  # NOTE: this is the compiler
  def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
    assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
    binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))]  # noqa: E501
    bindings = [{"binding": i, "resource": {"buffer": create_uniform(x) if i >= len(bufs) else x, "offset": 0, "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)]  # noqa: E501
    bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
    pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
    bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
    compute_pipeline = wgpu_device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
    command_encoder = wgpu_device.create_command_encoder()
    compute_pass = command_encoder.begin_compute_pass()
    compute_pass.set_pipeline(compute_pipeline)
    compute_pass.set_bind_group(0, bind_group, [], 0, 999999)  # last 2 not used
    compute_pass.dispatch_workgroups(*global_size)  # x y z
    compute_pass.end()
    wgpu_device.queue.submit([command_encoder.finish()])

class WebGpuAllocator(Allocator):
  def _alloc(self, size: int):
    return wgpu_device.create_buffer(size=size, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
  def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
  def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0)  # TODO: remove this copy

class WebGpuDevice(Compiled):
  def __init__(self, device:str):
    super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64],
                     global_max=[65535, 65535, 65535]), WGSLRenderer, lambda x: x, WebGPUProgram)
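For reference, a minimal smoke test of this backend might look like the sketch below. It is not part of the commit: it assumes the file above is importable as extra.backends.ops_webgpu, that a WebGPU adapter is available through wgpu-py, and it uses numpy only to build the input bytes. The buffer size, WGSL source, and entry point name "double_it" are made up for illustration.

# Hypothetical usage sketch (not part of the commit): double four floats on the GPU.
import numpy as np
from extra.backends.ops_webgpu import WebGpuAllocator, WebGPUProgram

wgsl = """
@group(0) @binding(0) var<storage,read_write> data0: array<f32>;
@compute @workgroup_size(1)
fn double_it(@builtin(global_invocation_id) gid: vec3<u32>) { data0[gid.x] = data0[gid.x] * 2.0; }
"""

alloc = WebGpuAllocator()
buf = alloc._alloc(4 * 4)                                  # storage buffer for 4 float32 values
alloc.copyin(buf, memoryview(np.arange(4, dtype=np.float32).tobytes()))
prg = WebGPUProgram("double_it", wgsl)                     # create_shader_module compiles the WGSL here
prg(buf, global_size=(4, 1, 1), local_size=(1, 1, 1))      # one workgroup per element
out = bytearray(4 * 4)
alloc.copyout(memoryview(out), buf)                        # queue.read_buffer copies the result back
print(np.frombuffer(bytes(out), dtype=np.float32))         # expected: [0. 2. 4. 6.]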
extra/backends/triton.py (Normal file, 131 lines)
@@ -0,0 +1,131 @@
from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
import math
import re

triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "fp64", dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}

def next_power_of_2(x):
  return 1 << (x - 1).bit_length()

def render_valid(valid):
  return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True'

#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile
def fill_dims_for_idx(idx, dims):
  return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx

def get_max(var):
  if isinstance(var, int): return var
  return re.sub(r'\[(.*?)\]', '', str(var))[1:-1]

#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved
def remove_single_scalar_curly_braces(ptx_code):
  return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')])

def render_const(args,dtype:DType):
  return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args))

def render_cast(x:str, dtype:DType, bitcast=False):
  return f"{x}.to({triton_dtypes[dtype]}, bitcast={bitcast})"

def define_scalar(local_size, dtype, args):
  if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})"
  return render_const(args,dtype)

def uops_to_triton(function_name:str, uops:List[UOp]):
  local_size: List[int] = []
  depth = 1
  signatures, dims, bufs, kernel, valid = [], [], [], [], []  # type: ignore

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t"):
    nonlocal c, r
    c[prefix] += 1
    r[u]=f"{prefix}{c[prefix]-1}"
    return r[u]

  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for ru in uops:
    for v in ru.vin:
      child_count[v] += 1

  def kk(s): kernel.append(" "*depth+s)
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
    UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
    UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
    UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
    UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
    BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
    BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
    BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
    BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
    BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
    TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
    TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
  }
  def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.LOOP:
      kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
      depth += 1
    elif uop == UOps.END: depth -= 1
    elif uop == UOps.ALU:
      assert dtype is not None
      val = code_for_op[args](*[r[x] for x in vin])
      if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
      else: kk(f"{ssa(u, 'alu')} = ({val})")
    elif uop == UOps.LOAD:
      assert dtype is not None
      if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
      else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
    elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
    elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
    elif uop == UOps.PHI:
      kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
      r[u] = r[vin[0]]
    elif uop == UOps.STORE:
      assert not isinstance(dtype, ImageDType), "unimplemented: image store"
      kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
    elif uop == UOps.DEFINE_GLOBAL:
      bufs.append(args)
      signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
      r[u] = args
    elif uop == UOps.SPECIAL:
      dims.append(args[1])
      valid.append(f"{args[1]}<{get_max(args[2])}")
      if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
      elif args[1].startswith("l"):
        kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
        local_size.append(args[2])
      r[u] = args[1]
    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
    else: raise NotImplementedError(f"unimplemented: {uop}")

  prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
  for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
  prg += "\n".join(kernel)

  acc_local_size = 1
  for x in local_size: acc_local_size *= next_power_of_2(x)
  local_size = [acc_local_size] + [1] * (len(local_size) - 1)

  if DEBUG >= 4: print(prg)
  getlines = linecache.getlines
  linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
  exec(compile(prg, "<triton>", "exec"), globals())  # pylint: disable=W0122
  compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
  max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
  for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])

  return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
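As a rough illustration of what the string-rendering helpers in this file produce, here is a hypothetical snippet that is not part of the commit. It assumes the file above is importable as extra.backends.triton (which pulls in the triton package at import time); the printed strings follow directly from the helper definitions and need no GPU.

# Hypothetical usage sketch of the helpers above (assumes triton is installed; no GPU needed for these calls).
from tinygrad.helpers import dtypes
from extra.backends.triton import next_power_of_2, render_valid, render_cast, define_scalar

print(next_power_of_2(5))                       # 8: local dims are padded to a power of two for tl.arange
print(render_valid(["lidx0<3", "gidx0<10"]))    # (lidx0<3) and gidx0<10  -> mask expression used by loads/stores
print(render_cast("val0", dtypes.float16))      # val0.to(tl.float16, bitcast=False)
print(define_scalar([3], dtypes.float32, 0.0))  # tl.full((4,),0.0, dtype=tl.float32)  -> accumulator init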