cleanup stale examples/extra (#13764)

* cleanup stale files

* examples

* move those back

* old

* delete more
This commit is contained in:
George Hotz
2025-12-19 16:27:37 -04:00
committed by GitHub
parent 80b84f5267
commit df6cde8a00
45 changed files with 0 additions and 5039 deletions

View File

@@ -1,189 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.opt.kernel import Ops, MemOp, UOp
from tinygrad.uop.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.uop.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
class Register(NamedTuple):
nm:str
dtype:DType
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes.float.vec(4):
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: Ops
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
#TODO: these should be global vars
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
ins: List[AssemblyInstruction] = []
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes.float.vec(4):
for off in range(4):
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
self.cnts[(dtype, scalar)] += 1
return ret
def render_numnode(self, b) -> Register:
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return self.tor[key]
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in self.tor:
#if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return self.tor[key]
def render_cast(self, a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
return self.tor[key]
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def addr_w_offset(self, args):
assert isinstance(args, MemOp)
idx = args.idx*args.memory_dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = cast(int, nums[0])
reg = idx.render(self.render_ops, self)
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
return reg, None, off
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
#TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
for u in uops:
uop,dtype,vin,args,_ = u
if uop == Ops.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == Ops.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
elif uop == Ops.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif args[1] == 'local':
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
elif uop == Ops.CAST:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(u, dtype)
for i,sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == Ops.ALU:
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((u, "rcp"))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((u, "2pi"))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == Ops.DEFINE_REG:
reg = lang.newreg(u, dtype=dtype)
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
elif uop == Ops.SPECIAL:
lang.tor[u] = lang.tor[args]
elif uop == Ops.CONST:
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
elif uop == Ops.LOAD:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == Ops.STORE:
if args is None:
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
else:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if DEBUG >= 4:
for tins in lang.ins: print(tins)
return global_size, local_size

View File

@@ -1,177 +0,0 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def compute_offsets(total):
quotient, remainder = divmod(total, 4096)
return [4096]*quotient + [remainder] if remainder else [4096]*quotient
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
class ARM64Language(AssemblyLanguage): pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop:Optional[Ops] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(12))]
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
def mov_imm(value, reg):
# Manually move value into reg if value can't fit
if value.__class__ is not float and abs(value) > abs(65535):
ins.append(f"movz w15, #{value & 0xffff}")
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {reg}, w15")
elif reg[0] == 's':
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
ins.append("str x15, [sp, 16]")
ins.append(f"ldr {reg}, [sp, 16]")
else:
ins.append(f"mov {reg}, #{value}")
# Get variables intervals
live_range:Dict[str, List[int]] = {}
for i, (uop, out, vin, arg) in enumerate(asm):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
mem_vars:Dict[str, int] = {}
rtor:Dict[str, str] = {}
def allocate_regs(mvars):
nonlocal var_size
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
#NOTE: Very simple spill, everything that don't fit in regs goes to mem
if not available_regs:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x12', 'x13', 'x16']
for i, (uop, out, vin, arg) in enumerate(asm):
# Clear regs out of interval
for var, reg in list(rtor.items()):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Assign a registers to the variables using live ranges.
allocate_regs([out] + vin)
# Assign temp regs to vin and load them before direct use
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == Ops.SPECIAL:
if arg.startswith('data'):
# data 8 to n into the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
ins.append(f"mov {rtor[out.nm]}, x15")
else:
ins.append(f"mov {rtor[out.nm]}, #0")
ins.append(f"loop_{arg}:")
elif uop == Ops.CAST:
if arg == BinaryOps.CMPLT:
if rtor[out.nm][0] == 's':
mov_imm(0.0, 's0')
mov_imm(1.0, 's1')
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
if rtor[out.nm][0] == 'x':
mov_imm(0, 'x14')
mov_imm(1, 'x15')
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == Ops.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif arg == TernaryOps.WHERE:
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
#NOTE: Not a real instruction, use to emulate a ext call in unicorn
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
else:
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
# Save the registers before they are cleared by func call
for i,k in enumerate(save_regs,1):
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
ins.append("stp x29, x30, [sp, #0]!")
ins.append("mov x29, sp")
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
ins.append(alu[arg])
ins.append(f"fmov {rtor[out.nm]}, s0")
ins.append("mov sp, x29")
ins.append("ldp x29, x30, [sp], #0")
for i,k in enumerate(save_regs,1):
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg == BinaryOps.CMPLT:
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
elif arg == BinaryOps.MOD:
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == Ops.STORE:
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
elif uop == Ops.COND_BRANCH:
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
if prev_uop == Ops.LOAD:
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
elif uop == Ops.LABEL:
ins.append(f"{arg[1:]}:")
elif uop == Ops.ENDLOOP:
mov_imm(arg[0], "x15")
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
ins.append(f"b.lt loop_{arg[1]}")
prev_uop = uop
# store regs into memory if needed
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True

View File

@@ -1,105 +0,0 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
def render_cast(ins, inp, out):
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
elif out.dtype == dtypes.bool:
if inp.dtype == dtypes.bool:
ins.append(f"mov.pred {out}, {inp};")
else:
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
else:
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
supports_constant_folding: bool = True
def specialize_to_ptx(lang, function_name):
param_cnt = 0
ins = []
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
for uop, out, vin, arg in lang.ins:
if uop == Ops.ENDLOOP:
ins.append("bar.sync 0;")
elif uop == Ops.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == Ops.SPECIAL:
if arg.startswith('data'):
param_cnt += 1
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
# ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == Ops.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
if arg == TernaryOps.WHERE:
if vin[0].dtype == dtypes.bool:
reg = vin[0]
else:
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
vin = vin[1:] + [reg]
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
reg = lang.newreg((out, dt[0]), dtype=dt[1])
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
render_cast(ins, reg, out)
else:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == Ops.STORE:
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
if arg[2] == dtypes.bool != vin[1].dtype:
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
render_cast(ins, vin[1], prereg)
else: prereg = vin[1]
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
render_cast(ins, prereg, reg)
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
else:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == Ops.CAST:
render_cast(ins, vin[0], out)
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == Ops.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
ins = ins_prefix + ins
ins += ["ret;", "}"]
return '\n'.join(ins)
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
lang = PTXLanguage()
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True

View File

@@ -1,203 +0,0 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.opt.kernel import Ops
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH
# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""
code_start = """.end_amdhsa_kernel
.text
code:
"""
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
supports_float4: bool = True
supports_float4_alu: bool = True
supports_load3: bool = True
sin_is_sin2pi: bool = True
no_div: bool = True
def specialize(self, asm) -> Tuple[str, str]:
args = []
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
BinaryOps.CMPLT: "cmp_lt"}
pend_regs:Set[Register] = set()
rtor:Dict[Register, str] = {}
def reg_in(x):
nonlocal pend_regs
#print("reg_in", x, rtor[x], pend_regs)
if x in pend_regs:
#print("clear")
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
pend_regs.clear()
return rtor[x]
def reg_out(x):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == Ops.DEFINE_REGISTER:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
align = int(arg[0][0].itemsize / 4)
if arg[0][1]:
s_cnt += s_cnt % align
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
s_cnt += align
else:
v_cnt += v_cnt % align
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
v_cnt += align
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
if arg[0][0] == dtypes.float.vec(4):
for off in range(4):
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
elif arg[0][0] == dtypes.bool:
for i in range(arg[2]):
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
else:
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
elif uop == Ops.SPECIAL:
if arg.startswith('buf'):
i = int(arg[3:])
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif arg.startswith('gid'):
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
# the docs lied, this is actually y
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
# get local size
offset = len(args)*8
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
pend_regs.clear()
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
elif uop == Ops.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if out.dtype == dtypes.float.vec(4):
for off in range(4):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
elif uop == Ops.ALU:
if arg in [BinaryOps.CMPLT]:
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
alu_arg = alu[arg]
if arg == TernaryOps.MULACC and out == vin[2]:
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes.float.vec(4):
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
elif uop == Ops.LOAD:
if out.scalar:
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
else:
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif uop == Ops.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == Ops.COND_BRANCH:
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
elif uop == Ops.CAST:
if vin[0].dtype == dtypes.bool:
if out.dtype == dtypes.float32:
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
else:
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
else:
raise NotImplementedError(uop)
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
# dual alu group
seen = set()
new_ins = []
for i,tins in enumerate(ins):
if tins in seen: continue
if tins.startswith("v_fmac_f32"):
for gins in reversed(ins[i+1:]):
if gins in seen: continue
if gins.startswith("v_fmac_f32"):
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
if r0[0]%2 == r1[0]%2: continue
if r0[1]%2 == r1[1]%2: continue
if r0[2]%2 == r1[2]%2: continue
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
seen.add(tins)
seen.add(gins)
break
if tins not in seen:
new_ins.append(tins)
ins = new_ins
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
def assemble(self, args, ins, v_cnt, s_cnt):
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
metadata = {'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
return asm

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
if __name__ == "__main__":
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
prg = CUDAProgram("test", """
.version 7.8
.target sm_86
.address_size 64
.visible .entry test(.param .u64 x) {
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
ld.param.u64 %rd1, [x];
cvta.to.global.u64 %rd2, %rd1;
mov.u32 %r1, 0x40000000; // 2.0 in float
st.global.u32 [%rd2], %r1;
ret;
}""", binary=True)
prg([1], [1], test)
print(test.toCPU())