use "<" instead of "<=" in codegen for loop (#2027)

This commit is contained in:
chenyu
2023-10-09 20:26:36 -04:00
committed by GitHub
parent 25555c836f
commit 45f0891a8f
4 changed files with 5 additions and 5 deletions

View File

@@ -198,8 +198,8 @@ class Linearizer(OptimizedKernel):
# global and local loops
def render_loop(xx:List[Variable]):
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Variable, x.min).render(self.render_ops, self),
self.const(x.max) if isinstance(x.max, int) else cast(Variable, x.max).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
def end_loop(xx:List[Variable]):
for x in xx[::-1]:
if not isinstance(x, NumNode) and x.expr is not None:

View File

@@ -74,7 +74,7 @@ class CStyleLanguage(NamedTuple):
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
return f"({cond})?({x}):{y}"

View File

@@ -106,7 +106,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, lvars[vin[0].vin[1]]), bb[-1]._block, block._block)
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args[0]]]
if uop == UOps.DEFINE_ACC:

View File

@@ -41,7 +41,7 @@ class WGSLLanguage(CStyleLanguage):
return prg
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:
return f"select(f32({y}), {x}, bool({cond}))"