Files
tinygrad/extra/assembly/assembly.py
George Hotz d24f936501 just cmplt (#1493)
* just cmplt

* fix maximum

* don't save, there's no backward

* ugh, no slot either

* eq is a scam
2023-08-08 13:58:10 -07:00

183 lines
9.7 KiB
Python

from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
class Register(NamedTuple):
nm:str
dtype:DType
scalar:bool
off:Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes._float4:
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
supports_load3: bool = False
sin_is_sin2pi: bool = False
no_div: bool = False
def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
raise NotImplementedError("must be implemented")
# s registers are the addresses and non local indexes
def codegen(self):
self.process()
self.hand_coded_optimizations()
self.limit_global_dims(3) # all GPU asms have 3 (for now)
self.linearize()
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
tor: Dict[Any, Register] = {}
def newreg(tok, dtype=dtypes.float32, scalar=False):
nonlocal cnts, tor
if isinstance(tok, Token): dtype = tok.dtype # this
tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
if dtype == dtypes._float4:
for off in range(4):
tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
cnts[(dtype, scalar)] += 1
return ret
def render_numnode(b):
key = ("num", b)
if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return tor[key]
def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in tor:
#if not isinstance(b, Register): b = render_numnode(b)
ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return tor[key]
def render_cast(a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in tor:
ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
return tor[key]
render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
def addr_w_offset(args):
idx = args.idx*self.bufs[args.i].dtype.itemsize
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
idx -= nums[0]
off = nums[0]
reg = idx.render(render_ops)
if self.supports_load3:
if reg.scalar:
new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return tor[f"buf{args.i}"], reg, off
reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
return reg, None, off
ins = []
ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
global_size, local_size = [], []
skipload_branch = 0
for uop,newvar,vin,args in self.uops:
if uop == UOps.CONST and newvar is not None:
ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
elif uop == UOps.DEFINE_LOCAL:
ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif uop == UOps.CAST and newvar is not None:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = newreg(newvar)
for i,sr in enumerate(out.subregs()):
ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU and newvar is not None:
out = newreg(newvar) if newvar not in tor else tor[newvar]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and self.no_div:
tmp = newreg((newvar, "rcp"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and self.sin_is_sin2pi:
tmp = newreg((newvar, "2pi"))
ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
elif uop == UOps.LOAD and newvar is not None:
idx, treg, off = addr_w_offset(args)
reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
if args.valid.min == 0:
ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(render_ops)
ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
if args.valid.min == 0 and args.valid.max == 1:
ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
idx, treg, off = addr_w_offset(args)
ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
# define registers
ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins
if DEBUG >= 4:
for tins in ins: print(tins)
name, asm = self.specialize(ins)
return ASTRunner(name, asm,
global_size[::-1], local_size[::-1],
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})