mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-27 15:58:10 -05:00
now DEFINE_GLOBAL uop.arg[1] is always the same as uop.dtype, we can remove the one in arg and just use uop.dtype
189 lines
10 KiB
Python
189 lines
10 KiB
Python
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
|
|
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
|
from tinygrad.ops import BinaryOps, UnaryOps
|
|
from tinygrad.helpers import DType, dtypes, DEBUG
|
|
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
|
import functools
|
|
import math
|
|
from collections import defaultdict
|
|
|
|
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
|
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
|
|
|
class Register(NamedTuple):
|
|
nm:str
|
|
dtype:DType
|
|
scalar:bool
|
|
off:Optional[int] = None
|
|
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
|
def subregs(self):
|
|
if self.dtype == dtypes.float.vec(4):
|
|
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
|
return []
|
|
|
|
class AssemblyInstruction(NamedTuple):
|
|
op: UOps
|
|
out: Optional[Register]
|
|
vin: List[Union[Register, int, float]]
|
|
arg: Any = None
|
|
|
|
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
|
class AssemblyLanguage:
|
|
supports_load3: bool = False
|
|
sin_is_sin2pi: bool = False
|
|
no_div: bool = False
|
|
#TODO: these should be global vars
|
|
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
|
|
tor: Dict[Any, Register] = {}
|
|
ins: List[AssemblyInstruction] = []
|
|
|
|
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
|
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
|
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
|
if dtype == dtypes.float.vec(4):
|
|
for off in range(4):
|
|
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
|
self.cnts[(dtype, scalar)] += 1
|
|
return ret
|
|
|
|
def render_numnode(self, b) -> Register:
|
|
key = ("num", b)
|
|
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
|
return self.tor[key]
|
|
|
|
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
|
key = (op, a, b)
|
|
if key not in self.tor:
|
|
#if not isinstance(b, Register): b = render_numnode(b)
|
|
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
|
return self.tor[key]
|
|
|
|
def render_cast(self, a:Register, new_dtype:DType) -> Register:
|
|
if a.dtype == new_dtype: return a
|
|
key = (a, new_dtype)
|
|
if key not in self.tor:
|
|
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
|
|
return self.tor[key]
|
|
|
|
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
|
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
|
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
|
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
|
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
|
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
|
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
|
|
|
def addr_w_offset(self, args):
|
|
assert isinstance(args, MemOp)
|
|
idx = args.idx*args.memory_dtype.itemsize
|
|
off = 0 # TODO: should this be None?
|
|
if isinstance(idx, SumNode):
|
|
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
|
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
|
idx -= nums[0]
|
|
off = cast(int, nums[0])
|
|
reg = idx.render(self.render_ops, self)
|
|
if self.supports_load3:
|
|
if reg.scalar:
|
|
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
|
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
|
reg = new_reg
|
|
return self.tor[args.name], reg, off
|
|
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
|
|
return reg, None, off
|
|
|
|
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
|
#TODO: Do not use clear()
|
|
lang.ins.clear()
|
|
lang.tor.clear()
|
|
lang.cnts.clear()
|
|
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
|
global_size, local_size = [], []
|
|
skipload_branch = 0
|
|
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
|
for u in uops:
|
|
uop,dtype,vin,args,_ = u
|
|
if uop == UOps.DEFINE_LOCAL:
|
|
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
|
elif uop == UOps.LOOP:
|
|
if args[1] == "global":
|
|
for i,var in enumerate(args[0]):
|
|
global_size.append(var.max+1)
|
|
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
|
elif args[1] == "local":
|
|
for i,var in enumerate(args[0]):
|
|
local_size.append(var.max+1)
|
|
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
|
else:
|
|
for var in args[0]:
|
|
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
|
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
|
elif uop == UOps.ENDLOOP:
|
|
if args[1] not in ["global", "local", "global+local"]:
|
|
for var in reversed(args[0]):
|
|
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
|
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
|
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
|
elif args[1] == "global+local":
|
|
for i, var in enumerate(reversed(args[0])):
|
|
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
|
elif args[1] == 'local':
|
|
for i, var in enumerate(reversed(args[0])):
|
|
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
|
elif uop == UOps.CAST:
|
|
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
|
out = lang.newreg(u, dtype)
|
|
for i,sr in enumerate(out.subregs()):
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
|
elif uop == UOps.ALU:
|
|
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
|
# this is the only thing that can violate SSA
|
|
if args in [BinaryOps.CMPLT]:
|
|
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
|
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
|
elif args == BinaryOps.DIV and lang.no_div:
|
|
tmp = lang.newreg((u, "rcp"))
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
|
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
|
tmp = lang.newreg((u, "2pi"))
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
|
else:
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
|
elif uop == UOps.DEFINE_ACC:
|
|
reg = lang.newreg(u, dtype=dtype)
|
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
|
elif uop == UOps.SPECIAL:
|
|
lang.tor[u] = lang.tor[args]
|
|
elif uop == UOps.CONST:
|
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
|
elif uop == UOps.LOAD:
|
|
idx, treg, off = lang.addr_w_offset(args)
|
|
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
|
|
if args.valid.min == 0:
|
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
|
if args.valid.max == 1:
|
|
pred = args.valid.render(lang.render_ops, lang)
|
|
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
|
if args.valid.max == 1:
|
|
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
|
if args.valid.min == 0 and args.valid.max == 1:
|
|
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
|
skipload_branch += 1
|
|
elif uop == UOps.STORE:
|
|
if args is None:
|
|
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
|
else:
|
|
idx, treg, off = lang.addr_w_offset(args)
|
|
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
|
|
|
if DEBUG >= 4:
|
|
for tins in lang.ins: print(tins)
|
|
return global_size, local_size
|