removing loop (#1764)

* removing loop: LOOP/ENDLOOP no longer carry lists of index variables; each LOOP is a single range, closed by a matching END
* fix llvm: the LLVM renderer emits one basic block per LOOP uop instead of one per variable
* remove unused: drop the render_cl and render_llvm symbolic render tables
* strip parens: factor the paren stripping into a helpers.strip_parens function
* with side effects: dead-code elimination keeps a whitelist of uops with side effects instead of a purity blacklist
* define global has side effects: DEFINE_GLOBAL joins that whitelist
Author: George Hotz
Date: 2023-09-04 14:47:46 -07:00 (committed by GitHub)
Parent: 7344f7c2d1 / Commit: b32ed8e6e9
6 changed files with 73 additions and 104 deletions

tinygrad/codegen/linearizer.py

@@ -16,13 +16,10 @@ VariableOrNum = Union[Variable, NumNode, Node]
# bottom ones are asm only
class UOps(Enum):
LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702
LOOP = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
# TODO: add CONST. use ALU WHERE for gated load
# *** assembly only UOps ***
SPECIAL = auto(); LABEL = auto(); COND_BRANCH = auto() # TODO: replace these with LOOP and ENDLOOP # noqa: E702
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
idy = (idxy//(4*base_shape[1]))
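The first hunk collapses the old multi-variable LOOP/ENDLOOP pair into per-range LOOP/END uops and moves SPECIAL out of the asm-only group to stand for hardware thread indices. A minimal standalone sketch of the resulting linear structure (toy tuples, not tinygrad's UOp class):

from enum import Enum, auto

class UOps(Enum):
    LOOP = auto(); END = auto(); SPECIAL = auto(); ALU = auto()  # noqa: E702

def print_uops(uops):
    depth = 0
    for op, arg in uops:
        if op == UOps.END: depth -= 1      # END closes the most recent LOOP
        print("  " * depth + f"{op.name} {arg}")
        if op == UOps.LOOP: depth += 1

print_uops([
    (UOps.SPECIAL, "gidx0"),        # hardware global index, no loop emitted
    (UOps.LOOP, "ridx0 = 0..15"),   # a real counted loop: vin holds (start, end)
    (UOps.ALU, "acc += x"),
    (UOps.END, "ridx0"),            # vin points back at the LOOP uop it closes
])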
@@ -67,7 +64,7 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
nli.append(dd % s)
dd //= s
local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
return local_idxs, loop_local_idxs
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
class Linearizer(OptimizedKernel):
def get_buffer_name(self, i):
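get_grouped_dims now filters NumNodes out of the returned loop variables: a grouped dimension of size 1 collapses to the constant 0, which needs no LOOP or SPECIAL uop at all. A quick check, assuming tinygrad at this commit is importable:

from tinygrad.shape.symbolic import Variable, NumNode

idxs = [Variable("gidx0", 0, 7), NumNode(0)]  # second grouped dim had size 1
print([x.expr for x in idxs if not isinstance(x, NumNode)])  # ['gidx0']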
@@ -78,9 +75,9 @@ class Linearizer(OptimizedKernel):
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, cachable=True)
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
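render_ops is the class-to-lambda table that lowers symbolic index expressions into ALU uops. The change is in the leaves: a Variable no longer emits a fresh SPECIAL uop on every render but looks up the uop already registered in ctx.loop_uops, and NumNode goes through the new const() helper. A sketch of the dispatch pattern with stand-in classes (Var and Mul are illustrative, not tinygrad's):

class Var:  # stands in for Variable
    def __init__(self, expr): self.expr = expr
    def render(self, ops, ctx): return ops[type(self)](self, ops, ctx)

class Mul:  # stands in for MulNode
    def __init__(self, a, b): self.a, self.b = a, b
    def render(self, ops, ctx): return ops[type(self)](self, ops, ctx)

ops = {
    Var: lambda self, ops, ctx: ctx["loop_uops"][self.expr],  # reuse the existing uop
    Mul: lambda self, ops, ctx: f"({self.a.render(ops, ctx)}*{self.b})",
}
ctx = {"loop_uops": {"gidx0": "gidx0_uop"}}
print(Mul(Var("gidx0"), 4).render(ops, ctx))  # (gidx0_uop*4)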
@@ -118,11 +115,10 @@ class Linearizer(OptimizedKernel):
assert valid.min == 1
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
elif this_const is not None:
self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
self.load_cache[key] = self.const(this_const, localtype)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(self.render_ops, self)
alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)], TernaryOps.WHERE, cachable=True)
else:
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -133,8 +129,7 @@ class Linearizer(OptimizedKernel):
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 0:
valid_rendered = valid.render(self.render_ops, self)
alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, alt])
self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)])
else:
self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx])
ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
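Both branches of global_load now build the out-of-range fallback with self.const(invalid_value, localtype) instead of an inline CONST uop, either folded into a four-input LOAD or selected via TernaryOps.WHERE. The semantics in plain Python (a sketch of the meaning, not tinygrad's API):

def gated_load(buf, idx, valid, invalid_value=0.0):
    # UOps.LOAD with vin=(buf, idx, valid, alt), or the ALU/WHERE form for consts
    return buf[idx] if valid else invalid_value

buf = [1.0, 2.0, 3.0]
print([gated_load(buf, i, valid=(i < len(buf))) for i in range(4)])  # [1.0, 2.0, 3.0, 0.0]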
@@ -187,6 +182,7 @@ class Linearizer(OptimizedKernel):
# uops
self.uops: List[UOp] = []
self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
arg_bufs = {}
@@ -196,7 +192,8 @@ class Linearizer(OptimizedKernel):
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size()))
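Variables from symbolic shapes become int32 kernel arguments (DEFINE_GLOBAL) and are registered in the new loop_uops dict keyed by their expr string, the same dict the Variable entry in render_ops reads. A toy illustration of the registry (plain tuples standing in for UOps):

loop_uops = {}
loop_uops["start_pos"] = ("DEFINE_GLOBAL", "start_pos")  # symbolic shape variable
loop_uops["gidx0"] = ("SPECIAL", 0, "gidx0", 32)         # hardware thread index
loop_uops["ridx0"] = ("LOOP", 0, 15)                     # rendered counted loop
print(loop_uops["gidx0"])  # any later render of gidx0 reuses this uop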
@@ -226,8 +223,21 @@ class Linearizer(OptimizedKernel):
upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
# global and local loops
self.uop(UOps.LOOP, None, [], (loop_global_idxs, "global"))
self.uop(UOps.LOOP, None, [], (loop_local_idxs, "local"))
def render_loop(xx:List[Variable]):
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Variable, x.min).render(self.render_ops, self),
self.const(x.max) if isinstance(x.max, int) else cast(Variable, x.max).render(self.render_ops, self))) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
def end_loop(xx:List[Variable]):
for x in xx[::-1]:
if not isinstance(x, NumNode) and x.expr is not None:
loop_uop = self.loop_uops[x.expr]
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
if self.opts.has_local:
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
render_loop(loop_global_idxs+loop_local_idxs)
# parse AST
loaded_buffers = {}
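This is the heart of the PR: render_loop emits one LOOP uop per index variable with its (start, end) bounds as vin, end_loop walks the variables in reverse emitting an END that points back at each LOOP, and on backends with a launch grid (opts.has_local) the global/local indices become SPECIAL uops with no loop at all. A condensed sketch of that dispatch (toy tuples, not tinygrad's types):

def lower_indices(idxs, has_local):
    uops, loop_uops = [], {}
    for i, (expr, mn, mx) in enumerate(idxs):
        if has_local:  # gidx/lidx come from the launch grid, nothing to iterate
            loop_uops[expr] = ("SPECIAL", len(idxs) - 1 - i, expr, mx + 1)
        else:          # CPU-style: a real counted loop, closed later by END
            loop_uops[expr] = ("LOOP", mn, mx)
        uops.append(loop_uops[expr])
    return uops, loop_uops

print(lower_indices([("gidx0", 0, 31)], has_local=True)[0])   # [('SPECIAL', 0, 'gidx0', 32)]
print(lower_indices([("gidx0", 0, 31)], has_local=False)[0])  # [('LOOP', 0, 31)]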
@@ -245,7 +255,7 @@ class Linearizer(OptimizedKernel):
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
# reduce loop
self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
render_loop(reduce_idxs)
# barrier for fast GEMM
if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
@@ -314,7 +324,7 @@ class Linearizer(OptimizedKernel):
self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, do_reduce=True)
# end the reduce loop
self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
end_loop(reduce_idxs)
self.load_cache.clear()
# end the local loop, do the local reduce
@@ -322,7 +332,7 @@ class Linearizer(OptimizedKernel):
fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, [], ())
self.uop(UOps.ENDLOOP, None, [], (loop_local_idxs, "local"))
end_loop(loop_local_idxs)
# local indexs are over, 0 them out
local_idxs = [x*0 for x in local_idxs]
@@ -343,7 +353,7 @@ class Linearizer(OptimizedKernel):
# late reduce loop
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
render_loop(end_local_idxs)
# load localbufs
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
@@ -352,7 +362,7 @@ class Linearizer(OptimizedKernel):
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, do_reduce=True) # type: ignore
# end the late reduce loop
self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
end_loop(end_local_idxs)
self.load_cache.clear()
# load latebufs
@@ -365,16 +375,16 @@ class Linearizer(OptimizedKernel):
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the global (and maybe local) loop
self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
# (recursively) remove childless uops
UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
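The dead-uop sweep flips from a blacklist (keep anything whose type is not known pure) to a whitelist (keep a childless uop only if it is in UOPS_W_SIDE_EFFECTS); per the commit message, DEFINE_GLOBAL is in that set so unused buffer arguments are not pruned out of the kernel signature. A self-contained sketch of the fixed-point sweep, with (name, vin, has_side_effect) tuples standing in for UOps:

def eliminate_dead(uops):
    while True:
        has_child = {v for _, vin, _ in uops for v in vin}
        nu = [u for u in uops if u[0] in has_child or u[2]]  # consumed or effectful
        if len(nu) == len(uops): return nu  # fixed point reached
        uops = nu

uops = [
    ("const_a", (), False),        # consumed by alu -> kept
    ("const_b", (), False),        # childless and pure -> dropped
    ("alu", ("const_a",), False),  # consumed by store -> kept
    ("store", ("alu",), True),     # side effect -> always kept
]
print([u[0] for u in eliminate_dead(uops)])  # ['const_a', 'alu', 'store']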

tinygrad/helpers.py

@@ -17,6 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
def merge_dicts(ds:Iterable[Dict]) -> Dict:
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
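The new helper centralizes the paren stripping previously inlined in the cstyle renderer and hidden behind Node.render's strip_parens flag. Note it inspects only the outermost characters, so it is only safe on strings known to be a single fully parenthesized term:

def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst

print(strip_parens("(gidx0*4)"))  # gidx0*4
print(strip_parens("gidx0*4"))    # unchanged
print(strip_parens("(a)*(b)"))    # a)*(b -- call sites must guarantee the shape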

tinygrad/renderer/cstyle.py

@@ -3,13 +3,7 @@ import math
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
# div is different in cl than python
render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
class CStyleLanguage(NamedTuple):
size_prefix: str = "int"
@@ -74,7 +68,7 @@ class CStyleLanguage(NamedTuple):
def render_local(self, name:str, size:int):
return self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
@@ -98,16 +92,16 @@ class CStyleLanguage(NamedTuple):
assert var_dtype == dtypes._float4, "images must be float4"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype == dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
global_size: List[int] = []
local_size: List[int] = []
kernel,prekernel = [],[]
pend_close = None
#pend_close = None
bufs = []
depth = 0
def kk(s): kernel.append(" "*depth+s)
@@ -127,31 +121,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
for i,var in enumerate(args[0]):
if args[1] == "global" and lang.gid:
global_size.append(var.max+1)
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
elif args[1] == "local" and lang.lid:
local_size.append(var.max+1)
kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
else:
if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
r[u] = ssa('ridx')
kk(lang.render_for(r[u], r[vin[0]], r[vin[1]]))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
elif uop == UOps.ENDLOOP:
if args[1] == "local" and lang.lid:
# TODO: this is a bit of a hack. the local loop isn't real on the GPU
kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
else:
if args[1] == "global" and pend_close:
depth -= 1
kk(pend_close)
pend_close = None
depth -= 1
kk("}"*len(args[0]) + f" /* {args[1]} */")
elif uop == UOps.END:
depth -= 1
kk("}")
elif uop == UOps.WMMA:
if args == "METAL":
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
@@ -175,9 +152,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
assert dtype is not None
# remove parens if ALU types are the same. TODO: can do more here
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
fst = r[vin[0]]
if fst[0] == '(' and fst[-1] == ')': fst = fst[1:-1]
val = lang.code_for_op[args](fst, *[r[x] for x in vin[1:]])
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
else:
val = lang.code_for_op[args](*[r[x] for x in vin])
assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -191,7 +166,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
r[u] = ssa('acc')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL:
r[u] = args.expr
xid = lang.gid if args[1].startswith("g") else lang.lid
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
(global_size if args[1].startswith("g") else local_size).append(args[2])
r[u] = args[1]
elif uop == UOps.CONST:
r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop == UOps.LOAD:
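Put together, the renderer's new paths produce roughly the shape below for a reduce kernel: SPECIAL becomes an id read that also records the launch size, LOOP goes through render_for, and END closes with a single brace, retiring the old pend_close bookkeeping. A hand-written illustration (not captured output):

kernel = """
int gidx0 = get_group_id(0);                /* UOps.SPECIAL; appends to global_size */
float acc0 = 0.0f;                          /* UOps.DEFINE_ACC */
for (int ridx0 = 0; ridx0 <= 15; ++ridx0) { /* UOps.LOOP via render_for */
  acc0 = (acc0+data1[(gidx0*16)+ridx0]);    /* ALU + LOAD; strip_parens on the index */
}                                           /* UOps.END */
data0[gidx0] = acc0;                        /* UOps.STORE */
"""
print(kernel)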

tinygrad/renderer/llvmir.py

@@ -1,22 +1,9 @@
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
import functools
from llvmlite import ir # type: ignore
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(32), a) if isinstance(a, int) else a.render(ops, ctx)
render_llvm = {
NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
}
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=('fast',)),
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
@@ -87,11 +74,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
func.attributes.add('"no-nans-fp-math"="true"')
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks = []
loop_blocks: List = []
reduce_phis: List = []
# TODO: newvar probably shouldn't be optional
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
for bufname,dtype in buf_to_dtype.items():
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
@@ -99,30 +85,26 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
for var in args[0]:
if isinstance(var, NumNode): continue
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}")))
bb[-2].branch(bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1]._block)
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(ir.FloatType())
lvars[rp].add_incoming(incoming, bb[-2]._block)
phis.append((rp, lvars[rp]))
loop_blocks.append((bb[-1], phis))
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(ir.FloatType())
lvars[rp].add_incoming(incoming, bb[-2]._block)
phis.append((rp, lvars[rp]))
lvars[var.expr] = bb[-1].phi(ir.IntType(32), name=var.expr)
lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
if uop == UOps.ENDLOOP:
for var in args[0][::-1]:
if isinstance(var, NumNode): continue
block, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
loop_blocks.append((bb[-1], phis))
if uop == UOps.END:
block, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, lvars[vin[0].vin[1]]), bb[-1]._block, block._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args[0]]]
if uop == UOps.DEFINE_ACC:
@@ -137,7 +119,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
assert dtype is not None
if len(vin) > 2:
gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
aug_idx = bb[-1].select(gate, lvars[vin[1]], sym_render(0))
aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
val = cast(bb, val, vin[0].dtype, dtype)
val = bb[-1].select(gate, val, lvars[vin[3]])
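The LLVM renderer mirrors the same structure in basic blocks: each LOOP uop opens a body block whose induction variable is a phi fed first by the start value and, once END arrives, by the incremented counter along the conditional back-edge. A minimal runnable sketch of that shape (assumes llvmlite is installed; names are illustrative):

from llvmlite import ir

module = ir.Module(name="loop_sketch")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), []), name="kernel")
entry = ir.IRBuilder(func.append_basic_block("entry"))
body_block = func.append_basic_block("loop_body_0")
entry.branch(body_block)

body = ir.IRBuilder(body_block)
idx = body.phi(ir.IntType(32), name="loop0")                   # induction variable
idx.add_incoming(ir.Constant(ir.IntType(32), 0), entry.block)  # start value (LOOP vin[0])

# ... the loop body would use idx here ...

idx_p1 = body.add(idx, ir.Constant(ir.IntType(32), 1))
idx.add_incoming(idx_p1, body.block)                           # back-edge value (at END)
exit_block = func.append_basic_block("loop_exit_0")
body.cbranch(body.icmp_unsigned(">", idx_p1, ir.Constant(ir.IntType(32), 15)),
             exit_block, body_block)
ir.IRBuilder(exit_block).ret_void()
print(module)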

tinygrad/renderer/wgsl.py

@@ -40,7 +40,7 @@ class WGSLLanguage(CStyleLanguage):
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
return prg, global_size[::-1] if global_size else [1], local_size
def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:

tinygrad/shape/symbolic.py

@@ -15,12 +15,10 @@ class Node(ABC):
b: Union[Node, int]
min: int
max: int
def render(self, ops=None, ctx=None, strip_parens=False) -> str:
def render(self, ops=None, ctx=None) -> str:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
ret = ops[type(self)](self, ops, ctx)
if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
return ret
return ops[type(self)](self, ops, ctx)
def vars(self): return []
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
def expand(self) -> List[Node]: raise NotImplementedError(self.__class__.__name__)
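With the flag gone, render() always returns the fully parenthesized form, and callers that want it bare wrap the result in helpers.strip_parens instead. A quick check, assuming tinygrad at this commit is importable:

from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import strip_parens

s = (Variable("gidx0", 0, 31) * 4).render()
print(s)                # (gidx0*4)
print(strip_parens(s))  # gidx0*4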