Fix compile error for max with inline const (#5840)

This commit is contained in:
ignaciosica
2024-08-12 12:40:39 -03:00
committed by GitHub
parent e9024c691f
commit 777d6b3349
2 changed files with 4 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ import unittest
from typing import List, cast
import numpy as np
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.device import Buffer, CompileError, Device
from tinygrad.device import Buffer, Device
from tinygrad.dtype import PtrDType, DType, dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten
@@ -36,12 +36,8 @@ class TestCStyleFailures(unittest.TestCase):
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)))
store = UOp.store(a, idx, alu)
# CLANG doesn't use the max function
if Device.DEFAULT in ["CLANG"]:
ret = _test_uop_result([Tensor([1])], [store])[0]
assert ret[0] == 1
# call to 'max' is ambiguous
else:
with self.assertRaises(CompileError): _test_uop_result([Tensor([1])], [store])
ret = _test_uop_result([Tensor([1])], [store])[0]
self.assertEqual(ret[0], 1)
if __name__ == '__main__':
unittest.main()

View File

@@ -137,6 +137,7 @@ class CStyleLanguage(Renderer):
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], cast(DType, v.dtype)) if v.op is UOps.CONST else r[v] for v in src]
else: operands = [r[v] for v in src]
val = self.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"