fixing branch condition on UOps.IF in the ptx renderer (#7315)

* fixing branch condition on UOps.IF in the ptx renderer

* ptx works

---------

Co-authored-by: Nick Talati <nick.talati@quantworks.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
talati
2024-10-27 08:27:38 -04:00
committed by GitHub
parent a410b46c1d
commit d4d201d87b
2 changed files with 6 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from tinygrad.codegen.linearize import linearize_uop
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, getenv, prod
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, UOps
from tinygrad.renderer import Program
@@ -62,11 +62,7 @@ class TestPTXFailures(unittest.TestCase):
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
if getenv("PTX"):
with self.assertRaises(AssertionError):
np.testing.assert_equal(ret, [0, 1, 1, 1])
else: np.testing.assert_equal(ret, [0, 1, 1, 1])
np.testing.assert_equal(ret, [0, 1, 1, 1])
if __name__ == '__main__':
unittest.main()