mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
cleanup stale examples/extra (#13764)
* cleanup stale files
* examples
* move those back
* old
* delete more
@@ -1,189 +0,0 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.opt.kernel import Ops, MemOp, UOp
from tinygrad.uop.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
from tinygrad.uop.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
                   dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}

class Register(NamedTuple):
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    if self.dtype == dtypes.float.vec(4):
      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
    return []

class AssemblyInstruction(NamedTuple):
  op: Ops
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
  supports_load3: bool = False
  sin_is_sin2pi: bool = False
  no_div: bool = False
  #TODO: these should be global vars
  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
  tor: Dict[Any, Register] = {}
  ins: List[AssemblyInstruction] = []

  def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
  def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
    self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
    if dtype == dtypes.float.vec(4):
      for off in range(4):
        self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
    self.cnts[(dtype, scalar)] += 1
    return ret
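  # Hedged illustration of the naming scheme above (not in the original file):
  #   newreg("a")                                  -> %f0  (vector float32; cnts is keyed on (dtype, scalar))
  #   newreg("b")                                  -> %f1
  #   newreg("i", dtype=dtypes.int32, scalar=True) -> %I0  (scalar registers use the uppercase letter)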

  def render_numnode(self, b) -> Register:
    key = ("num", b)
    if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
    return self.tor[key]

  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
    key = (op, a, b)
    if key not in self.tor:
      #if not isinstance(b, Register): b = render_numnode(b)
      self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
    return self.tor[key]

  def render_cast(self, a:Register, new_dtype:DType) -> Register:
    if a.dtype == new_dtype: return a
    key = (a, new_dtype)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
    return self.tor[key]

  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
                      MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                      DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                      ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                      LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
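  # Hedged sketch of how render_ops is used (not in the original file): rendering a
  # symbolic index like i*4+2, i.e. SumNode([MulNode(i, 4), NumNode(2)]), recurses
  # through these lambdas: MulNode calls render_alu(BinaryOps.MUL, <reg for i>, 4),
  # NumNode loads the constant via render_numnode, and SumNode folds the parts with
  # BinaryOps.ADD. Repeated subexpressions hit the same (op, a, b) key in self.tor,
  # so each distinct expression is emitted only once (a simple form of CSE).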

  def addr_w_offset(self, args):
    assert isinstance(args, MemOp)
    idx = args.idx*args.memory_dtype.itemsize
    off = 0  # TODO: should this be None?
    if isinstance(idx, SumNode):
      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
      if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0:  # TODO: different for each GPU?
        idx -= nums[0]
        off = cast(int, nums[0])
    reg = idx.render(self.render_ops, self)
    if self.supports_load3:
      if reg.scalar:
        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
        self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
        reg = new_reg
      return self.tor[args.name], reg, off
    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
    return reg, None, off

def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  #TODO: Do not use clear()
  lang.ins.clear()
  lang.tor.clear()
  lang.cnts.clear()
  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
  global_size, local_size = [], []
  skipload_branch = 0
  lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == Ops.DEFINE_LOCAL:
      lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
      lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
    elif uop == Ops.LOOP:
      if args[1] == "global":
        for i,var in enumerate(args[0]):
          global_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
      elif args[1] == "local":
        for i,var in enumerate(args[0]):
          local_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
      else:
        for var in args[0]:
          if not isinstance(var, NumNode):  # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
            lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
    elif uop == Ops.ENDLOOP:
      if args[1] not in ["global", "local", "global+local"]:
        for var in reversed(args[0]):
          if not isinstance(var, NumNode):  # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
            lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif args[1] == "global+local":
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
      elif args[1] == 'local':
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
    elif uop == Ops.CAST:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      out = lang.newreg(u, dtype)
      for i,sr in enumerate(out.subregs()):
        lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
    elif uop == Ops.ALU:
      out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
      # this is the only thing that can violate SSA
      if args in [BinaryOps.CMPLT]:
        pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
        lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
        lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
      elif args == BinaryOps.DIV and lang.no_div:
        tmp = lang.newreg((u, "rcp"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        tmp = lang.newreg((u, "2pi"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
      else:
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
    elif uop == Ops.DEFINE_REG:
      reg = lang.newreg(u, dtype=dtype)
      lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
    elif uop == Ops.SPECIAL:
      lang.tor[u] = lang.tor[args]
    elif uop == Ops.CONST:
      lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
    elif uop == Ops.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      if args.valid.min == 0:
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
        if args.valid.max == 1:
          pred = args.valid.render(lang.render_ops, lang)
          lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
      if args.valid.max == 1:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
      if args.valid.min == 0 and args.valid.max == 1:
        lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
        skipload_branch += 1
    elif uop == Ops.STORE:
      if args is None:
        lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))

  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size
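Both deleted backends below drive this module the same way; a minimal sketch, assuming a linearized `uops` list (elided here):

    lang = ARM64Language()  # or PTXLanguage()
    global_size, local_size = uops_to_asmstyle(lang, "kernel", uops)
    # lang.ins now holds AssemblyInstruction tuples for specialize_to_arm64 / specialize_to_ptx to lower to text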
@@ -1,177 +0,0 @@
import struct
from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
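# e.g. float_to_hex(2.0) == "40000000": struct.pack("f", 2.0) gives b'\x00\x00\x00\x40'
# little-endian, and the [::-1] reversal prints the IEEE 754 bits most-significant first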
def compute_offsets(total):
  quotient, remainder = divmod(total, 4096)
  return [4096]*quotient + [remainder] if remainder else [4096]*quotient
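# e.g. compute_offsets(9000) -> [4096, 4096, 808]; the stack is moved in chunks so each
# `sub sp, sp, #imm` below stays within the 12-bit add/sub immediate (4096 itself encodes as 1 lsl 12)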

#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name

class ARM64Language(AssemblyLanguage): pass

def specialize_to_arm64(fn_nm, asm):
  var_size = 16
  prev_uop:Optional[Ops] = None
  ins = []
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Get variable live intervals
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}
  rtor:Dict[str, str] = {}
  def allocate_regs(mvars):
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: very simple spill, everything that doesn't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs that are past their live interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign registers to the variables using live ranges
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == Ops.SPECIAL:
      if arg.startswith('data'):
        # data args 8 to n are passed on the stack
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == Ops.CAST:
      if arg == BinaryOps.CMPLT:
        if rtor[out.nm][0] == 's':
          mov_imm(0.0, 's0')
          mov_imm(1.0, 's1')
          ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
        if rtor[out.nm][0] == 'x':
          mov_imm(0, 'x14')
          mov_imm(1, 'x15')
          ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == Ops.ALU:
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: not a real instruction, used to emulate an external call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are clobbered by the function call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
        ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
      else:
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == Ops.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == Ops.STORE:
      #NOTE: if casting is needed, store via the s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else ''} {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == Ops.COND_BRANCH:
      #TODO: this is a hack; there shouldn't always need to be a cmp before a cond branch
      if prev_uop == Ops.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == Ops.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == Ops.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)] + ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] + ["ret", "\n"])

def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  lang = ARM64Language()
  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
@@ -1,105 +0,0 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.opt.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])

def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
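# e.g. ptx_needs_cast(dtypes.int32, dtypes.float32) -> True   (float/int crossing)
#      ptx_needs_cast(dtypes.float16, dtypes.float32) -> True (float width change)
#      ptx_needs_cast(dtypes.int64, dtypes.int32) -> False    (no int-to-int clause matches)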

def render_cast(ins, inp, out):
  if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
    ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
  elif out.dtype == dtypes.bool:
    if inp.dtype == dtypes.bool:
      ins.append(f"mov.pred {out}, {inp};")
    else:
      ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
  else:
    round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
    ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")

# https://docs.nvidia.com/cuda/parallel-thread-execution/#

class PTXLanguage(AssemblyLanguage):
  supports_constant_folding: bool = True

def specialize_to_ptx(lang, function_name):
  param_cnt = 0
  ins = []
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
         TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
  for uop, out, vin, arg in lang.ins:
    if uop == Ops.ENDLOOP:
      ins.append("bar.sync 0;")
    elif uop == Ops.DEFINE_LOCAL:
      ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
    elif uop == Ops.SPECIAL:
      if arg.startswith('data'):
        param_cnt += 1
        ins.append(f"ld.param.u64 {out}, [{arg}];")
        # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
        # ins.append(f"cvta.to.global.u64 {out}, {out};")
      elif arg.startswith('gid'):
        ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
      elif arg.startswith('lid'):
        ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
    elif uop == Ops.ALU:
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
      else:
        otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
        if arg == TernaryOps.WHERE:
          if vin[0].dtype == dtypes.bool:
            reg = vin[0]
          else:
            reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
            ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
          vin = vin[1:] + [reg]
        ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
    elif uop == Ops.LOAD:
      if arg.__class__ in (int, float):
        ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
      elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
        dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
        reg = lang.newreg((out, dt[0]), dtype=dt[1])
        ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
        render_cast(ins, reg, out)
      else:
        ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
    elif uop == Ops.STORE:
      if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
        if arg[2] == dtypes.bool != vin[1].dtype:
          prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
          render_cast(ins, vin[1], prereg)
        else: prereg = vin[1]
        reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
        render_cast(ins, prereg, reg)
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
      else:
        ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
    elif uop == Ops.CAST:
      render_cast(ins, vin[0], out)
    elif uop == Ops.LABEL:
      ins.append(f"{arg}:")
    elif uop == Ops.COND_BRANCH:
      ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

  ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
                f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
  for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
  ins = ins_prefix + ins
  ins += ["ret;", "}"]
  return '\n'.join(ins)

def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
  lang = PTXLanguage()
  global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
  return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
@@ -1,203 +0,0 @@
import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.opt.kernel import Ops
from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH

# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()

boilerplate_start = """
.global _start
_start:
.rodata
.align 0x10
.global code.kd
.type code.kd,STT_OBJECT
.amdhsa_kernel code"""

code_start = """.end_amdhsa_kernel
.text
code:
"""

# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
class RDNACodegen(AssemblyCodegen):
  supports_float4: bool = True
  supports_float4_alu: bool = True
  supports_load3: bool = True
  sin_is_sin2pi: bool = True
  no_div: bool = True

  def specialize(self, asm) -> Tuple[str, str]:
    args = []
    for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
    ins = []

    v_cnt = 3  # v[0:2] is local_xyz
    s_cnt = 5  # s[0:1] is the address, s[2:4] is global_xyz

    dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
    alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
           BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
           UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
           BinaryOps.CMPLT: "cmp_lt"}

    pend_regs:Set[Register] = set()
    rtor:Dict[Register, str] = {}
    def reg_in(x):
      nonlocal pend_regs
      #print("reg_in", x, rtor[x], pend_regs)
      if x in pend_regs:
        #print("clear")
        ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
        pend_regs.clear()
      return rtor[x]
    def reg_out(x):
      return rtor[x]
    for uop, out, vin, arg in asm:
      if uop == Ops.DEFINE_REGISTER:
        if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
          for i in range(arg[2]):
            # TODO: Re-use gaps created by this to avoid wasting registers
            align = int(arg[0][0].itemsize / 4)
            if arg[0][1]:
              s_cnt += s_cnt % align
              reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
              s_cnt += align
            else:
              v_cnt += v_cnt % align
              reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
              v_cnt += align
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name

            if arg[0][0] == dtypes.float.vec(4):
              for off in range(4):
                reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
                rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
        elif arg[0][0] == dtypes.bool:
          for i in range(arg[2]):
            reg_name = "scc" if arg[0][1] else "vcc_lo"  # `_lo` suffix since we're running wavefront_size=32
            rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
        else:
          raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
      elif uop == Ops.SPECIAL:
        if arg.startswith('buf'):
          i = int(arg[3:])
          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
          pend_regs.add(out)
          for r in out.subregs(): pend_regs.add(r)
        elif arg.startswith('gid'):
          ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
          # the docs lied, this is actually y
          if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10")  # untested
          if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
          elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
          # get local size
          offset = len(args)*8
          args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
          ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
          ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
          pend_regs.clear()
          ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
          ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
      elif uop == Ops.CONST:
        if arg == float('inf'): arg = "0x7f800000"
        elif arg == float('-inf'): arg = "0xff800000"
        if out.dtype == dtypes.float.vec(4):
          for off in range(4):
            ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
        else:
          ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
      elif uop == Ops.ALU:
        if arg in [BinaryOps.CMPLT]:
          ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
        else:
          alu_arg = alu[arg]
          if arg == TernaryOps.MULACC and out == vin[2]:
            alu_arg = "fmac"
            vin = vin[0:2]
          if out.dtype == dtypes.float.vec(4):
            for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
              ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
          else:
            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
      elif uop == Ops.LOAD:
        if out.scalar:
          # swap arg order
          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
        else:
          ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
        pend_regs.add(out)
        for r in out.subregs(): pend_regs.add(r)
      elif uop == Ops.STORE:
        ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
      elif uop == Ops.LABEL:
        ins.append(f"{arg}:")
      elif uop == Ops.COND_BRANCH:
        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
      elif uop == Ops.CAST:
        if vin[0].dtype == dtypes.bool:
          if out.dtype == dtypes.float32:
            ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
          else:
            raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
        else:
          raise NotImplementedError(uop)

    ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']

    # dual alu group
    seen = set()
    new_ins = []
    for i,tins in enumerate(ins):
      if tins in seen: continue
      if tins.startswith("v_fmac_f32"):
        for gins in reversed(ins[i+1:]):
          if gins in seen: continue
          if gins.startswith("v_fmac_f32"):
            r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
            r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
            if r0[0]%2 == r1[0]%2: continue
            if r0[1]%2 == r1[1]%2: continue
            if r0[2]%2 == r1[2]%2: continue
            new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
            seen.add(tins)
            seen.add(gins)
            break
      if tins not in seen:
        new_ins.append(tins)
    ins = new_ins
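    # Hedged illustration (not in the original file): "v_fmac_f32 v4, v8, v12" and
    # "v_fmac_f32 v5, v9, v13" use opposite even/odd register banks for every operand
    # (that is what the %2 comparisons above check), so the pass fuses them into:
    #   v_dual_fmac_f32 v4, v8, v12 :: v_dual_fmac_f32 v5, v9, v13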

    return 'code', self.assemble(args, ins, v_cnt, s_cnt)

  def assemble(self, args, ins, v_cnt, s_cnt):
    kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
                   '.amdhsa_next_free_vgpr': v_cnt,  # this matters!
                   '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
                   '.amdhsa_next_free_sgpr': s_cnt,
                   '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
                   '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
                   '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
                   '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2,  # is amdhsa_system_vgpr_workitem_id real?
                   '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
                   '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
                   '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}

    metadata = {'amdhsa.kernels': [{'.args': args,
                                    '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
                                    '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
                                    '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
                                    '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
                                    '.wavefront_size': 32}],
                'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}

    code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
    obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
    asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
    return asm
@@ -1,23 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer

if __name__ == "__main__":
  test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
  prg = CUDAProgram("test", """
.version 7.8
.target sm_86
.address_size 64
.visible .entry test(.param .u64 x) {
  .reg .b32 %r<2>;
  .reg .b64 %rd<3>;

  ld.param.u64 %rd1, [x];
  cvta.to.global.u64 %rd2, %rd1;
  mov.u32 %r1, 0x40000000; // 2.0 in float
  st.global.u32 [%rd2], %r1;
  ret;
}""", binary=True)
  prg([1], [1], test)
  print(test.toCPU())
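  # expected output: [2. 0. 0. 0. 0. 0. 0. 0. 0. 0.] (the kernel writes 2.0 into the first float)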