from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict
from tinygrad.codegen.linearizer import Linearizer, UOps, Token
from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
import functools
import math
from collections import defaultdict

_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes._float4: 'x'}
def type_to_letter(x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
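# register names are a dtype letter plus a running counter, upper-cased for scalar (warp-uniform)
# registers: e.g. %f0 is the first vector float32 register, %I0 the first scalar int32 register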

class Register(NamedTuple):
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
  def subregs(self):
    if self.dtype == dtypes._float4:
      return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
    return []
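# a _float4 Register aliases four scalar float sub-registers, e.g.
# Register("%x0", dtypes._float4, False).subregs() -> [%x0:0, %x0:1, %x0:2, %x0:3]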

class AssemblyInstruction(NamedTuple):
  op: UOps
  out: Optional[Register]
  vin: List[Union[Register, int, float]]
  arg: Any = None
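# AssemblyInstruction fields: op is the UOp kind, out the destination register (None for stores,
# branches and labels), vin the input operands, and arg any extra payload (constant value, label
# name, BinaryOps/UnaryOps op, etc.)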

# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyCodegen(Linearizer):
  supports_load3: bool = False
  sin_is_sin2pi: bool = False
  no_div: bool = False
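  # backend capability flags: supports_load3 keeps base pointer and index separate in loads/stores,
  # sin_is_sin2pi means the target's sin expects its input pre-scaled by 1/(2*pi), and no_div
  # lowers division to reciprocal-then-multiply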

  def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
    raise NotImplementedError("must be implemented")
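  # concrete backends implement specialize() to translate the AssemblyInstruction list into a
  # (kernel name, assembly text) pair for their target ISA, e.g. PTX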

  # s registers are the addresses and non local indexes
  def codegen(self):
    self.process()
    self.hand_coded_optimizations()
    self.limit_global_dims(3)  # all GPU asms have 3 (for now)
    self.linearize()

    cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
    tor: Dict[Any, Register] = {}
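    # tor maps linearizer tokens (and ad-hoc keys) to registers; newreg hands out SSA-style names
    # like %f0, %f1, ... per (dtype, scalar) pair, and a _float4 allocation also records a scalar
    # float view of the register for each of its four offsets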
    def newreg(tok, dtype=dtypes.float32, scalar=False):
      nonlocal cnts, tor
      if isinstance(tok, Token): dtype = tok.dtype  # a Token carries its own dtype
      tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar)
      if dtype == dtypes._float4:
        for off in range(4):
          tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off)
      cnts[(dtype, scalar)] += 1
      return ret

    def render_numnode(b):
      key = ("num", b)
      if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b))
      return tor[key]

    def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
      key = (op, a, b)
      if key not in tor:
        #if not isinstance(b, Register): b = render_numnode(b)
        ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
      return tor[key]

    def render_cast(a:Register, new_dtype:DType) -> Register:
      if a.dtype == new_dtype: return a
      key = (a, new_dtype)
      if key not in tor:
        ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a]))
      return tor[key]
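    # render_ops lowers symbolic index nodes (Variable/Num/Mul/Div/Mod/Lt/Sum/And) to ALU
    # instructions via node.render(render_ops); repeated sub-expressions reuse the register cached
    # under their key in tor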

    render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b),
                   MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                   DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                   ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                   LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
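    # addr_w_offset returns (base, index register, constant offset): a small constant term (< 4096)
    # in the index is peeled off into the immediate offset; with supports_load3 the buffer pointer
    # and index register stay separate, otherwise the index is cast to uint64 and added to the
    # pointer up front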

    def addr_w_offset(args):
      idx = args.idx*self.bufs[args.i].dtype.itemsize
      off = 0  # TODO: should this be None?
      if isinstance(idx, SumNode):
        nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
        if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0:  # TODO: different for each GPU?
          idx -= nums[0]
          off = nums[0]
      reg = idx.render(render_ops)
      if self.supports_load3:
        if reg.scalar:
          new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype)
          ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
          reg = new_reg
        return tor[f"buf{args.i}"], reg, off
      reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64)
      return reg, None, off

    ins = []
    ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))]
    global_size, local_size = [], []
    skipload_branch = 0
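    # walk the linearized uops and emit AssemblyInstructions: loops become LABEL/COND_BRANCH pairs,
    # and partially-valid loads zero the destination register, then branch around the LOAD to a
    # $skipload_<n> label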
    for uop,newvar,vin,args in self.uops:
      if uop == UOps.CONST and newvar is not None:
        ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args))
      elif uop == UOps.DEFINE_LOCAL:
        ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
        ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
      elif uop == UOps.LOOP:
        if args[1] == "global":
          for i,var in enumerate(args[0]):
            global_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
        elif args[1] == "local":
          for i,var in enumerate(args[0]):
            local_size.append(var.max+1)
            ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
        else:
          for var in args[0]:
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
              ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
      elif uop == UOps.ENDLOOP:
        if args[1] not in ["global", "local", "global+local"]:
          for var in reversed(args[0]):
            if not isinstance(var, NumNode):  # TODO: why is this coming through?
              ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD))
              pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool)
              ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif uop == UOps.CAST and newvar is not None:
        # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
        out = newreg(newvar)
        for i,sr in enumerate(out.subregs()):
          ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP))
      elif uop == UOps.ALU and newvar is not None:
        out = newreg(newvar) if newvar not in tor else tor[newvar]
        # this is the only thing that can violate SSA
        if args in [BinaryOps.CMPLT]:
          pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool)
          ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args))
          ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
        elif args == BinaryOps.DIV and self.no_div:
          tmp = newreg((newvar, "rcp"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP))
          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL))
        elif args == UnaryOps.SIN and self.sin_is_sin2pi:
          tmp = newreg((newvar, "2pi"))
          ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
          ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
        else:
          ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args))
      elif uop == UOps.LOAD and newvar is not None:
        idx, treg, off = addr_w_offset(args)
        reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype)))
        if args.valid.min == 0:
          ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0))
          if args.valid.max == 1:
            pred = args.valid.render(render_ops)
            ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
        if args.valid.max == 1:
          # NOTE: you can't compute the index in here, because it assumes it's all available later
          ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))
        if args.valid.min == 0 and args.valid.max == 1:
          ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
          skipload_branch += 1
      elif uop == UOps.STORE:
        idx, treg, off = addr_w_offset(args)
        ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared')))

    # define registers
    ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins

    if DEBUG >= 4:
      for tins in ins: print(tins)
    name, asm = self.specialize(ins)

    return ASTRunner(name, asm,
      global_size[::-1], local_size[::-1],
      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True})
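
# rough usage sketch (illustrative only; PTXCodegen and render_ptx are hypothetical names, not
# defined in this file): a backend subclasses AssemblyCodegen, sets the capability flags, and
# implements specialize(); codegen() then returns an ASTRunner with runtime_args={"binary": True}
#
#   class PTXCodegen(AssemblyCodegen):
#     supports_load3 = True
#     def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]:
#       return "exec", "\n".join(render_ptx(instr) for instr in asm)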