diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 3e06fbb50a..115bae38d8 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -16,13 +16,10 @@ VariableOrNum = Union[Variable, NumNode, Node]
 
 # bottom ones are asm only
 class UOps(Enum):
-  LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702
+  LOOP = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
   ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
-  # TODO: add CONST. use ALU WHERE for gated load
-  # *** assembly only UOps ***
-  SPECIAL = auto(); LABEL = auto(); COND_BRANCH = auto() # TODO: replace these with LOOP and ENDLOOP # noqa: E702
 
 def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
   idy = (idxy//(4*base_shape[1]))
@@ -67,7 +64,7 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
       nli.append(dd % s)
       dd //= s
     local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
-  return local_idxs, loop_local_idxs
+  return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
 
 class Linearizer(OptimizedKernel):
   def get_buffer_name(self, i):
@@ -78,9 +75,9 @@ class Linearizer(OptimizedKernel):
   def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
     render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
     return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
+  def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, cachable=True)
 
-  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
-                      NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
+  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
                       MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
                       DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
                       ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
@@ -118,11 +115,10 @@ class Linearizer(OptimizedKernel):
           assert valid.min == 1
           self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
         elif this_const is not None:
-          self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
+          self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
-            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
+            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)], TernaryOps.WHERE, cachable=True)
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
@@ -133,8 +129,7 @@ class Linearizer(OptimizedKernel):
           rendered_idx = idx.render(self.render_ops, self)
           if valid.min == 0:
             valid_rendered = valid.render(self.render_ops, self)
-            alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, alt])
+            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)])
           else:
             self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx])
       ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
@@ -187,6 +182,7 @@ class Linearizer(OptimizedKernel):
     # uops
     self.uops: List[UOp] = []
     self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs)
+    self.loop_uops: Dict[str, UOp] = {}
 
     # add global buffers
     arg_bufs = {}
@@ -196,7 +192,8 @@ class Linearizer(OptimizedKernel):
       if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
     # add variables from symbolic shapes
     for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
-      self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
+      assert var.expr is not None
+      self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
     # define local buffers
     for lb in self.local_alias.values():
       self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size()))
@@ -226,8 +223,21 @@ class Linearizer(OptimizedKernel):
     upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]]
 
     # global and local loops
-    self.uop(UOps.LOOP, None, [], (loop_global_idxs, "global"))
-    self.uop(UOps.LOOP, None, [], (loop_local_idxs, "local"))
+    def render_loop(xx:List[Variable]):
+      self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
+        self.const(x.min) if isinstance(x.min, int) else cast(Variable, x.min).render(self.render_ops, self),
+        self.const(x.max) if isinstance(x.max, int) else cast(Variable, x.max).render(self.render_ops, self))) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
+    def end_loop(xx:List[Variable]):
+      for x in xx[::-1]:
+        if not isinstance(x, NumNode) and x.expr is not None:
+          loop_uop = self.loop_uops[x.expr]
+          if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
+
+    if self.opts.has_local:
+      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
+      self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
+    else:
+      render_loop(loop_global_idxs+loop_local_idxs)
 
     # parse AST
     loaded_buffers = {}
@@ -245,7 +255,7 @@ class Linearizer(OptimizedKernel):
      acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # reduce loop
-      self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
+      render_loop(reduce_idxs)
 
       # barrier for fast GEMM
       if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
@@ -314,7 +324,7 @@ class Linearizer(OptimizedKernel):
       self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, do_reduce=True)
 
       # end the reduce loop
-      self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))
+      end_loop(reduce_idxs)
       self.load_cache.clear()
 
       # end the local loop, do the local reduce
@@ -322,7 +332,7 @@ class Linearizer(OptimizedKernel):
         fake_global_idxs = [x*0 for x in global_idxs]
         self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc)  # store accumulators
         self.uop(UOps.BARRIER, None, [], ())
-        self.uop(UOps.ENDLOOP, None, [], (loop_local_idxs, "local"))
+        end_loop(loop_local_idxs)
 
         # local indexs are over, 0 them out
         local_idxs = [x*0 for x in local_idxs]
@@ -343,7 +353,7 @@ class Linearizer(OptimizedKernel):
 
         # late reduce loop
         end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
-        self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
+        render_loop(end_local_idxs)
 
         # load localbufs
         loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs+upcast_idxs)
@@ -352,7 +362,7 @@ class Linearizer(OptimizedKernel):
         self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, do_reduce=True)  # type: ignore
 
         # end the late reduce loop
-        self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
+        end_loop(end_local_idxs)
         self.load_cache.clear()
 
     # load latebufs
@@ -365,16 +375,16 @@ class Linearizer(OptimizedKernel):
       self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
     # end the global (and maybe local) loop
-    self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
+    end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
 
     # (recursively) remove childless uops
-    UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
+    UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.WMMA, UOps.END, UOps.BARRIER, UOps.DEFINE_GLOBAL}
     while 1:
       has_child: Set[UOp] = set()
       for ru in self.uops:
        for vu in ru.vin:
          has_child.add(vu)
-      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
+      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS]
      if len(nu) == len(self.uops): break
      if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
      self.uops = nu
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 11bfb64982..77bbb8c062 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -17,6 +17,7 @@ def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
+def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 451dea76b9..f4e01068a2 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -3,13 +3,7 @@ import math
 from collections import defaultdict
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
-from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
-
-# div is different in cl than python
-render_cl = render_python.copy()
-render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
-render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
+from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
 
 class CStyleLanguage(NamedTuple):
   size_prefix: str = "int"
@@ -74,7 +68,7 @@ class CStyleLanguage(NamedTuple):
   def render_local(self, name:str, size:int):
     return self.smem_prefix + f"float {name}[{size}];"
 
-  def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
+  def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
 
   def render_conditional(self, cond: str, x:str, y:str) -> str:
@@ -98,16 +92,16 @@ class CStyleLanguage(NamedTuple):
       assert var_dtype == dtypes._float4, "images must be float4"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
     if self.uses_vload and buf_dtype == dtypes.float16:
-      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{strip_parens(idx)});"
     if var_dtype.sz > 1:
-      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
-    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{strip_parens(idx)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    return f"*({buf_name}+{strip_parens(idx)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
 def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
   global_size: List[int] = []
   local_size: List[int] = []
   kernel,prekernel = [],[]
-  pend_close = None
+  #pend_close = None
   bufs = []
   depth = 0
   def kk(s): kernel.append(" "*depth+s)
@@ -127,31 +121,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
   for u in uops:
     uop,dtype,vin,args,_ = u
     if uop == UOps.LOOP:
-      for i,var in enumerate(args[0]):
-        if args[1] == "global" and lang.gid:
-          global_size.append(var.max+1)
-          kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
-        elif args[1] == "local" and lang.lid:
-          local_size.append(var.max+1)
-          kk("{" if isinstance(var, NumNode) else f"{{ {lang.size_prefix} {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
-        else:
-          if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
-          kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
+      r[u] = ssa('ridx')
+      kk(lang.render_for(r[u], r[vin[0]], r[vin[1]]))
       depth += 1
     elif uop == UOps.BARRIER:
       kk(lang.barrier)
-    elif uop == UOps.ENDLOOP:
-      if args[1] == "local" and lang.lid:
-        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
-        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
-        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
-      else:
-        if args[1] == "global" and pend_close:
-          depth -= 1
-          kk(pend_close)
-          pend_close = None
-        depth -= 1
-        kk("}"*len(args[0]) + f" /* {args[1]} */")
+    elif uop == UOps.END:
+      depth -= 1
+      kk("}")
     elif uop == UOps.WMMA:
       if args == "METAL":
        # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
@@ -175,9 +152,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
       assert dtype is not None
       # remove parens if ALU types are the same. TODO: can do more here
       if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
-        fst = r[vin[0]]
-        if fst[0] == '(' and fst[-1] == ')': fst = fst[1:-1]
-        val = lang.code_for_op[args](fst, *[r[x] for x in vin[1:]])
+        val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
       else:
         val = lang.code_for_op[args](*[r[x] for x in vin])
       assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -191,7 +166,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
       r[u] = ssa('acc')
      kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
     elif uop == UOps.SPECIAL:
-      r[u] = args.expr
+      xid = lang.gid if args[1].startswith("g") else lang.lid
+      kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]};")
+      (global_size if args[1].startswith("g") else local_size).append(args[2])
+      r[u] = args[1]
     elif uop == UOps.CONST:
       r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
     elif uop == UOps.LOAD:
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 45a32e492a..098efa97ea 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,22 +1,9 @@
 from typing import Final, Dict, Callable, Any, List, Optional, Tuple
-import functools
 from llvmlite import ir  # type: ignore
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 
-def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(32), a) if isinstance(a, int) else a.render(ops, ctx)
-render_llvm = {
-  NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx),
-  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), sym_render(self.b,ops,ctx)),
-  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
-}
-
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=('fast',)),
   UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=('fast',)),
@@ -87,11 +74,10 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   func.attributes.add('"no-nans-fp-math"="true"')
 
   bb = [ir.IRBuilder(func.append_basic_block("entry"))]
-  loop_blocks = []
+  loop_blocks: List = []
   reduce_phis: List = []
   # TODO: newvar probably shouldn't be optional
   lvars: Dict[Optional[UOp], Any] = {}  # this Any is an llvm type
-  render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr]
 
   for bufname,dtype in buf_to_dtype.items():
     if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
@@ -99,30 +85,26 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
   for u in uops:
     uop,dtype,vin,args,_ = u
     if uop == UOps.LOOP:
-      for var in args[0]:
-        if isinstance(var, NumNode): continue
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{var.expr}")))
-        bb[-2].branch(bb[-1]._block)
+      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
+      bb[-2].branch(bb[-1]._block)
 
-        phis = []
-        for rp in reduce_phis:
-          incoming = lvars[rp]
-          lvars[rp] = bb[-1].phi(ir.FloatType())
-          lvars[rp].add_incoming(incoming, bb[-2]._block)
-          phis.append((rp, lvars[rp]))
-        loop_blocks.append((bb[-1], phis))
+      phis = []
+      for rp in reduce_phis:
+        incoming = lvars[rp]
+        lvars[rp] = bb[-1].phi(ir.FloatType())
+        lvars[rp].add_incoming(incoming, bb[-2]._block)
+        phis.append((rp, lvars[rp]))
 
-        lvars[var.expr] = bb[-1].phi(ir.IntType(32), name=var.expr)
-        lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block)
-    if uop == UOps.ENDLOOP:
-      for var in args[0][::-1]:
-        if isinstance(var, NumNode): continue
-        block, phis = loop_blocks.pop()
-        idx_p1 = bb[-1].add(lvars[var.expr], sym_render(1))
-        lvars[var.expr].add_incoming(idx_p1, bb[-1]._block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{var.expr}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, sym_render(var.max, render_llvm, bb[-2])), bb[-1]._block, block._block)
+      lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
+      lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
+      loop_blocks.append((bb[-1], phis))
+    if uop == UOps.END:
+      block, phis = loop_blocks.pop()
+      idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
+      lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
+      for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
+      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
+      bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, lvars[vin[0].vin[1]]), bb[-1]._block, block._block)
     if uop == UOps.DEFINE_GLOBAL:
       lvars[u] = func.args[buf_index[args[0]]]
     if uop == UOps.DEFINE_ACC:
@@ -137,7 +119,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
       assert dtype is not None
       if len(vin) > 2:
         gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
-        aug_idx = bb[-1].select(gate, lvars[vin[1]], sym_render(0))
+        aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
         val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
         val = cast(bb, val, vin[0].dtype, dtype)
         val = bb[-1].select(gate, val, lvars[vin[3]])
diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py
index dbaf571fab..f1820f9056 100644
--- a/tinygrad/renderer/wgsl.py
+++ b/tinygrad/renderer/wgsl.py
@@ -40,7 +40,7 @@ class WGSLLanguage(CStyleLanguage):
     prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
     return prg, global_size[::-1] if global_size else [1], local_size
 
-  def render_for(self, expr:str, _min:int, _max:Union[int,str]) -> str:
+  def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
     return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
 
   def render_conditional(self, cond:str, x:str, y:str) -> str:
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 9cad0fd19b..153c73ccf0 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -15,12 +15,10 @@ class Node(ABC):
   b: Union[Node, int]
   min: int
   max: int
-  def render(self, ops=None, ctx=None, strip_parens=False) -> str:
+  def render(self, ops=None, ctx=None) -> str:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
-    ret = ops[type(self)](self, ops, ctx)
-    if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
-    return ret
+    return ops[type(self)](self, ops, ctx)
   def vars(self): return []
   # expand a Node into List[Node] that enumerates the underlying Variables from min to max
   def expand(self) -> List[Node]: raise NotImplementedError(self.__class__.__name__)
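
Below is a minimal standalone sketch (not part of the patch itself) of the strip_parens helper this patch adds to tinygrad/helpers.py, illustrating the parenthesis cleanup the C-style renderer now applies to already-rendered index strings before pointer arithmetic; the idx, buf_name and var_name values are made-up examples, not output from a real kernel.

def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' else fst

# hypothetical rendered strings, roughly in the shape the renderer produces
idx, buf_name, var_name = "(gidx0*4)", "data0", "val0"
print(f"*({buf_name}+{strip_parens(idx)}) = {var_name};")  # -> *(data0+gidx0*4) = val0;
print(strip_parens("gidx0"))                               # no outer parens, returned unchanged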