diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index e709f653ac..fafd30d3d8 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -6,7 +6,7 @@ from tinygrad.codegen.linearize import linearize_uop from tinygrad.device import Buffer, Device from tinygrad.dtype import dtypes from tinygrad.engine.realize import CompiledRunner -from tinygrad.helpers import dedup, flatten, getenv, prod +from tinygrad.helpers import dedup, flatten, prod from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.ops import BinaryOps, UOp, UOps from tinygrad.renderer import Program @@ -62,11 +62,7 @@ class TestPTXFailures(unittest.TestCase): sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0] - - if getenv("PTX"): - with self.assertRaises(AssertionError): - np.testing.assert_equal(ret, [0, 1, 1, 1]) - else: np.testing.assert_equal(ret, [0, 1, 1, 1]) + np.testing.assert_equal(ret, [0, 1, 1, 1]) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index e27e2fdf26..1677cee330 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -96,7 +96,8 @@ class PTXRenderer(Renderer): def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"] - def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"] + def render_bra(self, b1, pred=None, invert=False) -> List[str]: + return [f"@{'!' if invert else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"] def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: assert dtype != dtypes.bool @@ -153,7 +154,8 @@ class PTXRenderer(Renderer): for u in uops: uop,dtype,src,args = u.op,u.dtype,u.src,u.arg if uop is UOps.IF: - kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True))) + pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True) + kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True)) elif uop is UOps.BARRIER and self.barrier: kk(self.barrier) elif uop is UOps.ENDRANGE: kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),