Mirror of https://github.com/tinygrad/tinygrad.git (last synced 2026-01-25 23:08:06 -05:00)
use "<" instead of "<=" in codegen for loop (#2027)
This commit is contained in:
@@ -198,8 +198,8 @@ class Linearizer(OptimizedKernel):
|
||||
# global and local loops
|
||||
def render_loop(xx:List[Variable]):
|
||||
self.loop_uops.update({x.expr:self.uop(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Variable, x.min).render(self.render_ops, self),
|
||||
self.const(x.max) if isinstance(x.max, int) else cast(Variable, x.max).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None})
|
||||
def end_loop(xx:List[Variable]):
|
||||
for x in xx[::-1]:
|
||||
if not isinstance(x, NumNode) and x.expr is not None:
|
||||
|
||||
@@ -74,7 +74,7 @@ class CStyleLanguage(NamedTuple):
|
||||
return self.smem_align + self.smem_prefix + f"float {name}[{size}];"
|
||||
|
||||
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
|
||||
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
|
||||
|
||||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
return f"({cond})?({x}):{y}"
|
||||
|
||||
@@ -106,7 +106,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned(">", idx_p1, lvars[vin[0].vin[1]]), bb[-1]._block, block._block)
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
|
||||
if uop == UOps.DEFINE_GLOBAL:
|
||||
lvars[u] = func.args[buf_index[args[0]]]
|
||||
if uop == UOps.DEFINE_ACC:
|
||||
|
||||
@@ -41,7 +41,7 @@ class WGSLLanguage(CStyleLanguage):
|
||||
return prg
|
||||
|
||||
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
|
||||
return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
|
||||
|
||||
def render_conditional(self, cond:str, x:str, y:str) -> str:
|
||||
return f"select(f32({y}), {x}, bool({cond}))"
|
||||
|
||||
Reference in New Issue
Block a user