Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 14:58:46 -05:00
removing loop (#1764)
* removing loop
* fix llvm
* remove unused
* strip parens
* with side effects
* define global has side effects
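The diff below replaces the list-carrying LOOP/ENDLOOP uops (which took a tuple of index Variables plus a "global"/"local"/"reduce" tag) with a counted LOOP whose inputs are its rendered min and max, an END that points back at the LOOP it closes, and a SPECIAL uop for global/local thread indices. As a minimal, self-contained sketch of those new shapes (illustrative only, not the commit's code; the field names mirror the UOp tuple used in the diff):

# sketch of the new LOOP/END/SPECIAL structure introduced by this commit
from collections import namedtuple
from enum import Enum, auto

class UOps(Enum):
  LOOP = auto(); END = auto(); SPECIAL = auto(); CONST = auto()  # noqa: E702

UOp = namedtuple("UOp", ["uop", "dtype", "vin", "arg"])
uops = []
def uop(op, dtype=None, vin=(), arg=None):
  u = UOp(op, dtype, tuple(vin), arg); uops.append(u); return u

# a counted loop is now LOOP(vin=(min_const, max_const)) ... END(vin=(that LOOP,))
lo, hi = uop(UOps.CONST, "int32", (), 0), uop(UOps.CONST, "int32", (), 31)
loop = uop(UOps.LOOP, "int32", (lo, hi))
# ...body uops that use `loop` as the induction variable would go here...
uop(UOps.END, None, (loop,))

# with opts.has_local, global/local indices become SPECIAL(arg=(dim, name, size)) instead of loops
gidx0 = uop(UOps.SPECIAL, "int32", (), (0, "gidx0", 16))

for u in uops: print(u.uop, u.arg)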
@@ -16,13 +16,10 @@ VariableOrNum = Union[Variable, NumNode, Node]
 
 # bottom ones are asm only
 class UOps(Enum):
-  LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702
+  LOOP = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
   ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
-  # TODO: add CONST. use ALU WHERE for gated load
-  # *** assembly only UOps ***
-  SPECIAL = auto(); LABEL = auto(); COND_BRANCH = auto() # TODO: replace these with LOOP and ENDLOOP # noqa: E702
 
 def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
   idy = (idxy//(4*base_shape[1]))
@@ -67,7 +64,7 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
       nli.append(dd % s)
       dd //= s
     local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
-  return local_idxs, loop_local_idxs
+  return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
 
 class Linearizer(OptimizedKernel):
   def get_buffer_name(self, i):
@@ -78,9 +75,9 @@ class Linearizer(OptimizedKernel):
   def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
     render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
     return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
+  def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, cachable=True)
 
-  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
-                      NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
+  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
                       MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
                       DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
                       ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
@@ -118,11 +115,10 @@ class Linearizer(OptimizedKernel):
           assert valid.min == 1
           self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
         elif this_const is not None:
-          self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
+          self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
-            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
+            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)], TernaryOps.WHERE, cachable=True)
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -133,8 +129,7 @@ class Linearizer(OptimizedKernel):
           rendered_idx = idx.render(self.render_ops, self)
           if valid.min == 0:
             valid_rendered = valid.render(self.render_ops, self)
-            alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, alt])
+            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)])
           else:
             self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx])
       ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
@@ -187,6 +182,7 @@ class Linearizer(OptimizedKernel):
     # uops
     self.uops: List[UOp] = []
     self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
+    self.loop_uops: Dict[str, UOp] = {}
 
     # add global buffers
     arg_bufs = {}
@@ -196,7 +192,8 @@ class Linearizer(OptimizedKernel):
       if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
     # add variables from symbolic shapes
     for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
-      self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
+      assert var.expr is not None
+      self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
     # define local buffers
     for lb in self.local_alias.values():
       self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size()))
@@ -226,8 +223,21 @@ class Linearizer(OptimizedKernel):
     upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
 
     # global and local loops
-    self.uop(UOps.LOOP, None, [], (loop_global_idxs, "global"))
-    self.uop(UOps.LOOP, None, [], (loop_local_idxs, "local"))
+    def render_loop(xx:List[Variable]):
+      self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
+        self.const(x.min) if isinstance(x.min, int) else cast(Variable, x.min).render(self.render_ops, self),
+        self.const(x.max) if isinstance(x.max, int) else cast(Variable, x.max).render(self.render_ops, self))) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
+    def end_loop(xx:List[Variable]):
+      for x in xx[::-1]:
+        if not isinstance(x, NumNode) and x.expr is not None:
+          loop_uop = self.loop_uops[x.expr]
+          if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
+
+    if self.opts.has_local:
+      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
+      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
+    else:
+      render_loop(loop_global_idxs+loop_local_idxs)
 
     # parse AST
     loaded_buffers = {}
@@ -245,7 +255,7 @@ class Linearizer(OptimizedKernel):
       acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # reduce loop
-      self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
+      render_loop(reduce_idxs)
 
       # barrier for fast GEMM
       if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
@@ -314,7 +324,7 @@ class Linearizer(OptimizedKernel):
       self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, do_reduce=True)
 
       # end the reduce loop
-      self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
+      end_loop(reduce_idxs)
       self.load_cache.clear()
 
       # end the local loop, do the local reduce
@@ -322,7 +332,7 @@ class Linearizer(OptimizedKernel):
         fake_global_idxs = [x*0 for x in global_idxs]
         self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
         self.uop(UOps.BARRIER, None, [], ())
-        self.uop(UOps.ENDLOOP, None, [], (loop_local_idxs, "local"))
+        end_loop(loop_local_idxs)
 
         # local indexs are over, 0 them out
         local_idxs = [x*0 for x in local_idxs]
@@ -343,7 +353,7 @@ class Linearizer(OptimizedKernel):
 
         # late reduce loop
         end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
-        self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
+        render_loop(end_local_idxs)
 
         # load localbufs
         loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
@@ -352,7 +362,7 @@ class Linearizer(OptimizedKernel):
         self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, do_reduce=True) # type: ignore
 
         # end the late reduce loop
-        self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
+        end_loop(end_local_idxs)
         self.load_cache.clear()
 
     # load latebufs
@@ -365,16 +375,16 @@ class Linearizer(OptimizedKernel):
       self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
     # end the global (and maybe local) loop
-    self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
+    end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
 
     # (recursively) remove childless uops
-    UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
+    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
     while 1:
       has_child: Set[UOp] = set()
       for ru in self.uops:
         for vu in ru.vin:
           has_child.add(vu)
-      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
+      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
       if len(nu) == len(self.uops): break
      if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
       self.uops = nu
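The pruning pass above flips from a blacklist (UOPS_WO_SIDE_EFFECTS) to a whitelist: only uops that observably do something (STORE, WMMA, the END that closes a loop, BARRIER, and DEFINE_GLOBAL, which shapes the kernel signature) are kept unconditionally, and everything else survives only while some remaining uop lists it in vin; iterated to a fixed point, that keeps exactly the ancestors of side-effecting uops. A standalone sketch of that loop (illustrative only; uop kinds are plain strings here instead of the UOps enum, and elements are assumed hashable like the UOp tuple in the diff):

# illustrative sketch of the childless-uop pruning above
from typing import List, Set

UOPS_W_SIDE_EFFECTS = {"STORE", "WMMA", "END", "BARRIER", "DEFINE_GLOBAL"}

def prune_childless(uops: List) -> List:
  # keep deleting uops that nothing consumes and that have no side effects, until stable
  while True:
    has_child: Set = set()
    for ru in uops:
      for vu in ru.vin: has_child.add(vu)
    nu = [x for x in uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
    if len(nu) == len(uops): return nu
    uops = nu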
@@ -17,6 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
+def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
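The new strip_parens helper drops one outer pair of parentheses and otherwise returns the string unchanged; the renderer call sites below pass rendered expressions that either carry a single outer pair or none at all. A quick illustration of the assumed behavior:

# behavior of the helper added above
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst

assert strip_parens("(gidx0*4)") == "gidx0*4"   # outer parens dropped
assert strip_parens("gidx0*4") == "gidx0*4"     # no outer parens, returned unchanged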
@@ -3,13 +3,7 @@ import math
 from collections import defaultdict
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
-from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
-
-# div is different in cl than python
-render_cl = render_python.copy()
-render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
-render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
+from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
 
 class CStyleLanguage(NamedTuple):
   size_prefix: str = "int"
@@ -74,7 +68,7 @@ class CStyleLanguage(NamedTuple):
   def render_local(self, name:str, size:int):
     return self.smem_prefix + f"float {name}[{size}];"
 
-  def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
+  def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
 
   def render_conditional(self, cond: str, x:str, y:str) -> str:
@@ -98,16 +92,16 @@ class CStyleLanguage(NamedTuple):
       assert var_dtype == dtypes._float4, "images must be float4"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
     if self.uses_vload and buf_dtype == dtypes.float16:
-      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
     if var_dtype.sz > 1:
-      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
-    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
 def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
   global_size: List[int] = []
   local_size: List[int] = []
   kernel,prekernel = [],[]
-  pend_close = None
+  #pend_close = None
   bufs = []
   depth = 0
   def kk(s): kernel.append(" "*depth+s)
@@ -127,31 +121,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
   for u in uops:
     uop,dtype,vin,args,_ = u
     if uop == UOps.LOOP:
-      for i,var in enumerate(args[0]):
-        if args[1] == "global" and lang.gid:
-          global_size.append(var.max+1)
-          kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
-        elif args[1] == "local" and lang.lid:
-          local_size.append(var.max+1)
-          kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
-        else:
-          if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
-          kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
+      r[u] = ssa('ridx')
+      kk(lang.render_for(r[u], r[vin[0]], r[vin[1]]))
       depth += 1
     elif uop == UOps.BARRIER:
       kk(lang.barrier)
-    elif uop == UOps.ENDLOOP:
-      if args[1] == "local" and lang.lid:
-        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
-        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
-        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
-      else:
-        if args[1] == "global" and pend_close:
-          depth -= 1
-          kk(pend_close)
-          pend_close = None
-        depth -= 1
-        kk("}"*len(args[0]) + f" /* {args[1]} */")
+    elif uop == UOps.END:
+      depth -= 1
+      kk("}")
     elif uop == UOps.WMMA:
       if args == "METAL":
         # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
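As a concrete illustration (not part of the commit), this is roughly what the C-style path now emits for one LOOP/END pair, assuming the loop's min/max uops render to 0 and 31 and the ssa counter hands out the name ridx0:

# sketch: render_for is copied from this diff; the matching UOps.END later prints a bare "}" and un-indents
def render_for(expr, _min, _max) -> str:
  return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"

print(render_for("ridx0", 0, 31))  # -> for (int ridx0 = 0; ridx0 <= 31; ++ridx0) {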
@@ -175,9 +152,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
       assert dtype is not None
       # remove parens if ALU types are the same. TODO: can do more here
       if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
-        fst = r[vin[0]]
-        if fst[0] == '(' and fst[-1] == ')': fst = fst[1:-1]
-        val = lang.code_for_op[args](fst, *[r[x] for x in vin[1:]])
+        val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
       else:
         val = lang.code_for_op[args](*[r[x] for x in vin])
       assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -191,7 +166,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
       r[u] = ssa('acc')
       kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
     elif uop == UOps.SPECIAL:
-      r[u] = args.expr
+      xid = lang.gid if args[1].startswith("g") else lang.lid
+      kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
+      (global_size if args[1].startswith("g") else local_size).append(args[2])
+      r[u] = args[1]
     elif uop == UOps.CONST:
       r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
     elif uop == UOps.LOAD:
@@ -1,22 +1,9 @@
 from typing import Final, Dict, Callable, Any, List, Optional, Tuple
-import functools
 from llvmlite import ir # type: ignore
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
 
-from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
-def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(32), a) if isinstance(a, int) else a.render(ops, ctx)
-render_llvm = {
-  NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
-  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
-}
-
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=('fast',)),
   UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
@@ -87,11 +74,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   func.attributes.add('"no-nans-fp-math"="true"')
 
   bb = [ir.IRBuilder(func.append_basic_block("entry"))]
-  loop_blocks = []
+  loop_blocks: List = []
   reduce_phis: List = []
   # TODO: newvar probably shouldn't be optional
   lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
-  render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
 
   for bufname,dtype in buf_to_dtype.items():
     if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
@@ -99,30 +85,26 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   for u in uops:
     uop,dtype,vin,args,_ = u
     if uop == UOps.LOOP:
-      for var in args[0]:
-        if isinstance(var, NumNode): continue
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}")))
-        bb[-2].branch(bb[-1]._block)
+      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
+      bb[-2].branch(bb[-1]._block)
 
-        phis = []
-        for rp in reduce_phis:
-          incoming = lvars[rp]
-          lvars[rp] = bb[-1].phi(ir.FloatType())
-          lvars[rp].add_incoming(incoming, bb[-2]._block)
-          phis.append((rp, lvars[rp]))
-        loop_blocks.append((bb[-1], phis))
+      phis = []
+      for rp in reduce_phis:
+        incoming = lvars[rp]
+        lvars[rp] = bb[-1].phi(ir.FloatType())
+        lvars[rp].add_incoming(incoming, bb[-2]._block)
+        phis.append((rp, lvars[rp]))
 
-        lvars[var.expr] = bb[-1].phi(ir.IntType(32), name=var.expr)
-        lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
-    if uop == UOps.ENDLOOP:
-      for var in args[0][::-1]:
-        if isinstance(var, NumNode): continue
-        block, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
-        lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
+      lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
+      lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
+      loop_blocks.append((bb[-1], phis))
+    if uop == UOps.END:
+      block, phis = loop_blocks.pop()
+      idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
+      lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
+      for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
+      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
+      bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, lvars[vin[0].vin[1]]), bb[-1]._block, block._block)
     if uop == UOps.DEFINE_GLOBAL:
       lvars[u] = func.args[buf_index[args[0]]]
     if uop == UOps.DEFINE_ACC:
@@ -137,7 +119,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       assert dtype is not None
       if len(vin) > 2:
         gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
-        aug_idx = bb[-1].select(gate, lvars[vin[1]], sym_render(0))
+        aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
         val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
         val = cast(bb, val, vin[0].dtype, dtype)
         val = bb[-1].select(gate, val, lvars[vin[3]])
@@ -40,7 +40,7 @@ class WGSLLanguage(CStyleLanguage):
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
     return prg, global_size[::-1] if global_size else [1], local_size
 
-  def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
+  def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
 
   def render_conditional(self, cond:str, x:str, y:str) -> str:
@@ -15,12 +15,10 @@ class Node(ABC):
   b: Union[Node, int]
   min: int
   max: int
-  def render(self, ops=None, ctx=None, strip_parens=False) -> str:
+  def render(self, ops=None, ctx=None) -> str:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
-    ret = ops[type(self)](self, ops, ctx)
-    if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
-    return ret
+    return ops[type(self)](self, ops, ctx)
   def vars(self): return []
   # expand a Node into List[Node] that enumerates the underlying Variables from min to max
   def expand(self) -> List[Node]: raise NotImplementedError(self.__class__.__name__)