mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fixing branch condition on UOps.IF in the ptx renderer (#7315)
* fixing branch condition on UOps.IF in the ptx renderer * ptx works --------- Co-authored-by: Nick Talati <nick.talati@quantworks.com> Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten, getenv, prod
|
||||
from tinygrad.helpers import dedup, flatten, prod
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps
|
||||
from tinygrad.renderer import Program
|
||||
@@ -62,11 +62,7 @@ class TestPTXFailures(unittest.TestCase):
|
||||
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
|
||||
if getenv("PTX"):
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
else: np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user