Mirror of https://github.com/tinygrad/tinygrad.git
@@ -1,5 +1,5 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
-from tinygrad.codegen.kernel import UOps, MemOp, UOp
+from tinygrad.codegen.kernel import Ops, MemOp, UOp
 from tinygrad.ops import BinaryOps, UnaryOps
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import DEBUG
@@ -23,7 +23,7 @@ class Register(NamedTuple):
     return []

 class AssemblyInstruction(NamedTuple):
-  op: UOps
+  op: Ops
   out: Optional[Register]
   vin: List[Union[Register, int, float]]
   arg: Any = None
@@ -49,21 +49,21 @@ class AssemblyLanguage:

   def render_numnode(self, b) -> Register:
     key = ("num", b)
-    if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
+    if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
     return self.tor[key]

   def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
     key = (op, a, b)
     if key not in self.tor:
       #if not isinstance(b, Register): b = render_numnode(b)
-      self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
+      self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
     return self.tor[key]

   def render_cast(self, a:Register, new_dtype:DType) -> Register:
     if a.dtype == new_dtype: return a
     key = (a, new_dtype)
     if key not in self.tor:
-      self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
+      self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
     return self.tor[key]

   render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
@@ -87,7 +87,7 @@ class AssemblyLanguage:
     if self.supports_load3:
       if reg.scalar:
         new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
-        self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+        self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
         reg = new_reg
       return self.tor[args.name], reg, off
     reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
@@ -98,91 +98,91 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
   lang.ins.clear()
   lang.tor.clear()
   lang.cnts.clear()
-  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
+  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
   global_size, local_size = [], []
   skipload_branch = 0
-  lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
+  lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
   for u in uops:
     uop,dtype,vin,args,_ = u
-    if uop == UOps.DEFINE_LOCAL:
-      lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
-      lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
-    elif uop == UOps.LOOP:
+    if uop == Ops.DEFINE_LOCAL:
+      lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
+      lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
+    elif uop == Ops.LOOP:
       if args[1] == "global":
         for i,var in enumerate(args[0]):
           global_size.append(var.max+1)
-          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
+          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
       elif args[1] == "local":
         for i,var in enumerate(args[0]):
           local_size.append(var.max+1)
-          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
+          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
       else:
         for var in args[0]:
           if not isinstance(var, NumNode): # TODO: why is this coming through?
-            lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
-            lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
-    elif uop == UOps.ENDLOOP:
+            lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
+            lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
+    elif uop == Ops.ENDLOOP:
       if args[1] not in ["global", "local", "global+local"]:
         for var in reversed(args[0]):
           if not isinstance(var, NumNode): # TODO: why is this coming through?
-            lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
+            lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
             pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
-            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+            lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
       elif args[1] == "global+local":
         for i, var in enumerate(reversed(args[0])):
-          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
+          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
       elif args[1] == 'local':
         for i, var in enumerate(reversed(args[0])):
-          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
-    elif uop == UOps.CAST:
+          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
+    elif uop == Ops.CAST:
       # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
       out = lang.newreg(u, dtype)
       for i,sr in enumerate(out.subregs()):
-        lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
-    elif uop == UOps.ALU:
+        lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
+    elif uop == Ops.ALU:
       out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
       # this is the only thing that can violate SSA
       if args in [BinaryOps.CMPLT]:
         pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
-        lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
-        lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
+        lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
       elif args == BinaryOps.DIV and lang.no_div:
         tmp = lang.newreg((u, "rcp"))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
       elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
         tmp = lang.newreg((u, "2pi"))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
       else:
-        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
-    elif uop == UOps.DEFINE_ACC:
+        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
+    elif uop == Ops.DEFINE_ACC:
       reg = lang.newreg(u, dtype=dtype)
-      lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
-    elif uop == UOps.SPECIAL:
+      lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
+    elif uop == Ops.SPECIAL:
      lang.tor[u] = lang.tor[args]
-    elif uop == UOps.CONST:
-      lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
-    elif uop == UOps.LOAD:
+    elif uop == Ops.CONST:
+      lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
+    elif uop == Ops.LOAD:
       idx, treg, off = lang.addr_w_offset(args)
       reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
       if args.valid.min == 0:
-        lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
+        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
         if args.valid.max == 1:
           pred = args.valid.render(lang.render_ops, lang)
-          lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
+          lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
       if args.valid.max == 1:
         # NOTE: you can't compute the index in here, because it assumes it's all available later
-        lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
+        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
       if args.valid.min == 0 and args.valid.max == 1:
-        lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
+        lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
         skipload_branch += 1
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       if args is None:
-        lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
       else:
         idx, treg, off = lang.addr_w_offset(args)
-        lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
+        lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))

   if DEBUG >= 4:
     for tins in lang.ins: print(tins)
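For orientation: every hunk in this change is the same mechanical rename. The enum imported from tinygrad.codegen.kernel is now called Ops instead of UOps, and every AssemblyInstruction is built from the renamed members. A minimal, self-contained sketch of the data structure the rename flows through (the Ops and Register definitions below are simplified stand-ins for illustration only; the real ones live in tinygrad.codegen.kernel and the file above):

    from enum import Enum, auto
    from typing import Any, List, NamedTuple, Optional, Union

    class Ops(Enum):  # stand-in for tinygrad.codegen.kernel.Ops (formerly UOps)
      LOAD = auto(); STORE = auto(); ALU = auto(); CAST = auto(); SPECIAL = auto()

    class Register(NamedTuple):  # simplified stand-in for the Register above
      nm: str
      scalar: bool = False

    class AssemblyInstruction(NamedTuple):  # mirrors the NamedTuple in this diff
      op: Ops
      out: Optional[Register]
      vin: List[Union[Register, int, float]]
      arg: Any = None

    # e.g. materializing a constant into a fresh scalar register, as render_numnode does:
    ins = [AssemblyInstruction(Ops.LOAD, Register("%num0", scalar=True), [], 42)]
    print(ins[0].op, ins[0].arg)  # Ops.LOAD 42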
@@ -3,7 +3,7 @@ from platform import system
 from typing import Tuple, Dict, List, Optional
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad.helpers import CI
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
@@ -19,7 +19,7 @@ class ARM64Language(AssemblyLanguage): pass

 def specialize_to_arm64(fn_nm, asm):
   var_size = 16
-  prev_uop:Optional[UOps] = None
+  prev_uop:Optional[Ops] = None
   ins = []
   x_regs = ['x' + str(i) for i in reversed(range(12))]
   s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
@@ -81,7 +81,7 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"mov x15, {mem_vars[v.nm]}")
         ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

-    if uop == UOps.SPECIAL:
+    if uop == Ops.SPECIAL:
       if arg.startswith('data'):
         # data 8 to n into the stack
         if int(arg[4:]) >= 8:
@@ -90,7 +90,7 @@ def specialize_to_arm64(fn_nm, asm):
       else:
         ins.append(f"mov {rtor[out.nm]}, #0")
         ins.append(f"loop_{arg}:")
-    elif uop == UOps.CAST:
+    elif uop == Ops.CAST:
       if arg == BinaryOps.CMPLT:
         if rtor[out.nm][0] == 's':
           mov_imm(0.0, 's0')
@@ -102,7 +102,7 @@ def specialize_to_arm64(fn_nm, asm):
           ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
       else:
         ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
-    elif uop == UOps.ALU:
+    elif uop == Ops.ALU:
       if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
         ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
@@ -136,7 +136,7 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
       else:
         ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-    elif uop == UOps.LOAD:
+    elif uop == Ops.LOAD:
       if arg.__class__ in (int, float):
         mov_imm(arg, rtor[out.nm])
       else:
@@ -146,20 +146,20 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
         ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
         if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       #NOTE: if need casting load var in s/h0 or x/w12 temp regs
       reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
       if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
       ins.append(f"mov x15, #{arg[0]}")
       ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
-    elif uop == UOps.COND_BRANCH:
+    elif uop == Ops.COND_BRANCH:
       #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
-      if prev_uop == UOps.LOAD:
+      if prev_uop == Ops.LOAD:
         ins.append(f"cmp {rtor[vin[0].nm]}, #0")
       ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
-    elif uop == UOps.LABEL:
+    elif uop == Ops.LABEL:
       ins.append(f"{arg[1:]}:")
-    elif uop == UOps.ENDLOOP:
+    elif uop == Ops.ENDLOOP:
       mov_imm(arg[0], "x15")
       ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
       ins.append(f"cmp {rtor[vin[0].nm]}, x15")
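A detail worth noting in the ARM64 hunks: specialize_to_arm64 tracks prev_uop so a conditional branch emits an explicit cmp when its predicate register was just LOADed, because b.lt/b.ge test condition flags that a plain ldr does not set. A sketch of that rule in isolation (the function and parameter names here are illustrative stand-ins, not the backend's real API):

    # pred_was_loaded corresponds to prev_uop == Ops.LOAD in the diff above
    def render_cond_branch(pred_reg: str, label: str, branch_if_lt: bool, pred_was_loaded: bool) -> list:
      ins = []
      if pred_was_loaded: ins.append(f"cmp {pred_reg}, #0")  # set flags from the loaded bool
      ins.append(f"b.{'lt' if branch_if_lt else 'ge'} {label}")
      return ins

    print(render_cond_branch("w9", "loop_gidx0", True, True))
    # ['cmp w9, #0', 'b.lt loop_gidx0']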
@@ -1,7 +1,7 @@
 from typing import List
 import struct
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_cuda import arch
@@ -37,11 +37,11 @@ def specialize_to_ptx(lang, function_name):
          UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
          TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
   for uop, out, vin, arg in lang.ins:
-    if uop == UOps.ENDLOOP:
+    if uop == Ops.ENDLOOP:
       ins.append("bar.sync 0;")
-    elif uop == UOps.DEFINE_LOCAL:
+    elif uop == Ops.DEFINE_LOCAL:
       ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
-    elif uop == UOps.SPECIAL:
+    elif uop == Ops.SPECIAL:
       if arg.startswith('data'):
         param_cnt += 1
         ins.append(f"ld.param.u64 {out}, [{arg}];")
@@ -51,7 +51,7 @@ def specialize_to_ptx(lang, function_name):
         ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
       elif arg.startswith('lid'):
         ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
-    elif uop == UOps.ALU:
+    elif uop == Ops.ALU:
       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
         ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
       else:
@@ -64,7 +64,7 @@ def specialize_to_ptx(lang, function_name):
           ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
           vin = vin[1:] + [reg]
         ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
-    elif uop == UOps.LOAD:
+    elif uop == Ops.LOAD:
       if arg.__class__ in (int, float):
         ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
       elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
@@ -74,7 +74,7 @@ def specialize_to_ptx(lang, function_name):
         render_cast(ins, reg, out)
       else:
         ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
         if arg[2] == dtypes.bool != vin[1].dtype:
           prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
@@ -85,11 +85,11 @@ def specialize_to_ptx(lang, function_name):
         ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
       else:
         ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
-    elif uop == UOps.CAST:
+    elif uop == Ops.CAST:
       render_cast(ins, vin[0], out)
-    elif uop == UOps.LABEL:
+    elif uop == Ops.LABEL:
       ins.append(f"{arg}:")
-    elif uop == UOps.COND_BRANCH:
+    elif uop == Ops.COND_BRANCH:
       ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

   ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
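In the PTX hunks, float immediates are rendered as '0f'+float_to_hex(arg), because PTX spells a float32 literal as 0f followed by its IEEE-754 bits in hex. A plausible self-contained equivalent of that helper (float_to_hex is defined elsewhere in this file; this reimplementation is an assumption, shown only to make the encoding concrete):

    import struct

    def float_to_hex(x: float) -> str:
      # big-endian IEEE-754 binary32 bits as 8 uppercase hex digits
      return "%02X%02X%02X%02X" % tuple(struct.pack("!f", x))

    print("0f" + float_to_hex(1.0))  # 0f3F800000, as in: mov.f32 %f0, 0f3F800000;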
@@ -2,7 +2,7 @@ import yaml
 from typing import Tuple, Set, Dict
 from tinygrad import dtypes
 from tinygrad.codegen.assembly import AssemblyCodegen, Register
-from tinygrad.codegen.kernel import UOps
+from tinygrad.codegen.kernel import Ops
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH

@@ -61,7 +61,7 @@ class RDNACodegen(AssemblyCodegen):
     def reg_out(x):
       return rtor[x]
     for uop, out, vin, arg in asm:
-      if uop == UOps.DEFINE_REGISTER:
+      if uop == Ops.DEFINE_REGISTER:
         if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
           for i in range(arg[2]):
             # TODO: Re-use gaps created by this to avoid wasting registers
@@ -86,7 +86,7 @@ class RDNACodegen(AssemblyCodegen):
             rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
         else:
           raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
-      elif uop == UOps.SPECIAL:
+      elif uop == Ops.SPECIAL:
         if arg.startswith('buf'):
           i = int(arg[3:])
           ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
@@ -106,7 +106,7 @@ class RDNACodegen(AssemblyCodegen):
             pend_regs.clear()
           ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
           ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
-      elif uop == UOps.CONST:
+      elif uop == Ops.CONST:
         if arg == float('inf'): arg = "0x7f800000"
         elif arg == float('-inf'): arg = "0xff800000"
         if out.dtype == dtypes.float.vec(4):
@@ -114,7 +114,7 @@ class RDNACodegen(AssemblyCodegen):
             ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
         else:
           ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
-      elif uop == UOps.ALU:
+      elif uop == Ops.ALU:
         if arg in [BinaryOps.CMPLT]:
           ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
         else:
@@ -127,7 +127,7 @@ class RDNACodegen(AssemblyCodegen):
            ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
          else:
            ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
-      elif uop == UOps.LOAD:
+      elif uop == Ops.LOAD:
         if out.scalar:
           # swap arg order
           ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
@@ -135,13 +135,13 @@ class RDNACodegen(AssemblyCodegen):
           ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
         pend_regs.add(out)
         for r in out.subregs(): pend_regs.add(r)
-      elif uop == UOps.STORE:
+      elif uop == Ops.STORE:
         ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
-      elif uop == UOps.LABEL:
+      elif uop == Ops.LABEL:
         ins.append(f"{arg}:")
-      elif uop == UOps.COND_BRANCH:
+      elif uop == Ops.COND_BRANCH:
         ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
-      elif uop == UOps.CAST:
+      elif uop == Ops.CAST:
         if vin[0].dtype == dtypes.bool:
           if out.dtype == dtypes.float32:
             ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
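The RDNA hunks choose an opcode prefix from Register.scalar: s_ for SALU instructions on uniform scalar registers, v_ for per-lane VALU instructions, as in the mov_b32 lines above. A toy illustration of that selection (the helper name is hypothetical):

    def mov_b32(reg_name: str, scalar: bool, value: str) -> str:
      # SALU (s_mov_b32) for uniform values, VALU (v_mov_b32) for per-lane values
      return f"{'s_' if scalar else 'v_'}mov_b32 {reg_name}, {value}"

    print(mov_b32("s0", True, "0x7f800000"))   # s_mov_b32 s0, 0x7f800000  (float('inf'))
    print(mov_b32("v1", False, "0xff800000"))  # v_mov_b32 v1, 0xff800000  (float('-inf'))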